diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
index 82aad853c3..39378c1652 100644
--- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
@@ -128,7 +128,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
 std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
   std::string convert_format;
-  size_t counter = 0;
+  const size_t counter = 0;
   for (const auto &iter : format_counter) {
     if (counter < iter.second) {
       convert_format = iter.first;
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc
index f373594f4a..4bcf904444 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc
@@ -129,18 +129,22 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph,
                                                  const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
-  std::vector<AnfNodePtr> bn_outputs;
-  if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
-    MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
-    return nullptr;
-  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().size() < kBatchNormRealInputNum + 1) {
+  if (cnode->size() < kBatchNormRealInputNum + 1) {
     MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum
                  << ". The node should not be changed";
     return nullptr;
   }
+  if (!GetBoolAttr(cnode, kAttrIsTraining)) {
+    MS_LOG(INFO) << "The attr is_training should be true for the fusion";
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> bn_outputs;
+  if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
+    MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
+    return nullptr;
+  }
   AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
   std::vector<AnfNodePtr> bn_training_reduce_outputs;
   CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc
index 429523bd8b..f95406e5e1 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc
@@ -58,6 +58,10 @@ const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const
   auto input1 = GetAnfNodeByVar(equiv, input1_);
   auto input2 = GetAnfNodeByVar(equiv, input2_);
   auto sum = GetAnfNodeByVar(equiv, sum_var_);
+  if (!GetBoolAttr(sum, kAttrKeepDims)) {
+    MS_LOG(INFO) << "The attr keep_dims of sum should be true for the fusion";
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
   auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});
diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc
index 896ab71e09..11f5bc5083 100644
--- a/mindspore/ccsrc/pre_activate/common/helper.cc
+++ b/mindspore/ccsrc/pre_activate/common/helper.cc
@@ -722,5 +722,16 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
   MS_EXCEPTION_IF_NULL(value_node2);
   return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value());
 }
+
+bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    MS_LOG(INFO) << "The node is not a cnode";
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h
index 1206fe2430..bb6da936cb 100644
--- a/mindspore/ccsrc/pre_activate/common/helper.h
+++ b/mindspore/ccsrc/pre_activate/common/helper.h
@@ -180,6 +180,9 @@ AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);
 
 // Compare tuple getitem's index, return bool[n1's index < n2's index]
 bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);
+
+// Get a bool attr from a cnode; returns false if the attr is absent
+bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/single_batch_norm_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/single_batch_norm_fission_test.py
index 71db6c0b83..1ea31fba50 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/single_batch_norm_fission_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/single_batch_norm_fission_test.py
@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
 
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
-BatchNorm = P.BatchNorm()
+BatchNorm = P.BatchNorm(is_training=True)
 BNTrainingReduce = Primitive('BNTrainingReduce')
 BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/softmax_grad_ext_fusion.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/softmax_grad_ext_fusion.py
index 38b8be0493..52ba86aaa3 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/softmax_grad_ext_fusion.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/softmax_grad_ext_fusion.py
@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 
 Mul = P.Mul()
-ReduceSum = P.ReduceSum()
+ReduceSum = P.ReduceSum(keep_dims=True)
 Sub = P.Sub()
 SoftmaxGradExt = Primitive('SoftmaxGradExt')
 MakeTuple = Primitive('make_tuple')
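
Note: the new GetBoolAttr helper is deliberately conservative. It returns false when the node is not a CNode, when the attr is missing, and when the attr is explicitly false, so a pass skips its rewrite unless the attr is known to be true. A minimal sketch of the intended call pattern inside some other pass's Process() follows; the pass name ExampleFusion and the attr string "example_flag" are illustrative placeholders, not part of this patch:

    const AnfNodePtr ExampleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
      MS_EXCEPTION_IF_NULL(graph);
      MS_EXCEPTION_IF_NULL(node);
      // GetBoolAttr folds three checks into one call: node is a CNode,
      // the attr exists on it, and the attr's value is true.
      if (!GetBoolAttr(node, "example_flag")) {
        return nullptr;  // bail out and leave the graph unchanged
      }
      // ... build and return the rewritten node here ...
      return nullptr;
    }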