!1940 Fix: BatchNorm was not split (fission) when it had only 4 outputs

Merge pull request !1940 from huanghui/single-batchnorm-fission-pass
pull/1940/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0624263496

@ -99,6 +99,7 @@ namespace {
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
}
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());

@ -24,7 +24,7 @@ namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2, 3, 4};
constexpr size_t kBatchNormRealOutputNum = 5;
constexpr size_t kBatchNormLeastOutputNum = 1;
constexpr size_t kBatchNormRealInputNum = 3;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
bn_outputs->push_back(output);
output_num++;
}
return output_num == kBatchNormRealOutputNum;
return output_num > kBatchNormLeastOutputNum;
}
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {

Loading…
Cancel
Save