diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc index 2857e5e2b0..f6a941cc69 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc @@ -32,7 +32,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con return reduce_min; } -bool NeedOptmize(const TypeId &dtype, const std::vector &shape, const std::vector &axis) { +bool NeedOptimize(const TypeId &dtype, const std::vector &shape, const std::vector &axis) { if (dtype != kNumberTypeFloat32) { MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!"; return false; @@ -84,7 +84,7 @@ std::vector GetInferShape(const std::vector &shape, const std::v for (size_t item = 0; item < shape.size(); ++item) { if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) { if (keep_dims) { - // If keep_dims is true, curretn dimesion set to 1 + // If keep_dims is true, current dimension set to 1 shape_first.push_back(1); } } else { @@ -110,28 +110,31 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN CheckCNodeInputSize(cnode, 2); auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); - if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) { - MS_LOG(INFO) << "ReduceMin has no axis, no need optimize!"; + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) { + MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!"; return nullptr; } - auto axis = AnfAlgo::GetNodeAttr>(cnode, kAttrAxis); - if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { - MS_LOG(INFO) << "ReduceMin has no keep_dims, no need optimize!"; + auto axis_value = prim->GetAttr(kAttrAxis); + MS_EXCEPTION_IF_NULL(axis_value); + if (!axis_value->isa()) { return nullptr; } + auto axis = AnfAlgo::GetNodeAttr>(cnode, kAttrAxis); auto keep_dims = AnfAlgo::GetNodeAttr(cnode, kAttrKeepDims); - if (!NeedOptmize(dtype, shape, axis)) { + if (!NeedOptimize(dtype, shape, axis)) { MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString(); return nullptr; } // Create reduce_min1 CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode); - std::vector axis_fisrt = CalFirstAxis(shape, axis); - std::vector shape_first = GetInferShape(shape, axis_fisrt, keep_dims); + std::vector axis_first = CalFirstAxis(shape, axis); + std::vector shape_first = GetInferShape(shape, axis_first, keep_dims); AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get()); - AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_fisrt), reduce_min1); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1); // Create reduce_min2 CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode);