From 36d1aadf1c87e248d714fa181997c392b20aa3b4 Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 9 Jun 2020 19:31:12 +0800 Subject: [PATCH] fix when Batchnorm's output is 0,1,2,4, fission doesn't work --- .../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc | 2 +- .../ascend/ir_fission/single_batch_norm_fission.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 97c256a9a8..27b99840df 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -99,6 +99,7 @@ namespace { void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { MS_EXCEPTION_IF_NULL(ir_fusion_pm); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); @@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>()); } ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc index 0e07d54b2c..5f01f2fab2 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace opt { namespace { const std::vector<size_t> kOutputIndex{0, 1, 2, 3, 4}; -constexpr size_t kBatchNormRealOutputNum = 5; +constexpr size_t kBatchNormLeastOutputNum = 1; constexpr size_t 
kBatchNormRealInputNum = 3; bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { @@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s bn_outputs->push_back(output); output_num++; } - return output_num == kBatchNormRealOutputNum; + return output_num > kBatchNormLeastOutputNum; } AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {