From e87ac6525ea0835baccd03b0120435e213cefa0a Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Thu, 4 Jun 2020 19:56:01 +0800
Subject: [PATCH] Add batch norm fusion pattern for mix precision

Rename the existing mix-precision pattern to
FusedBatchNormMixPrecisionFusion0 and add a second pattern,
FusedBatchNormMixPrecisionFusion1, so that the fusion also matches
graphs where the Sub result is cast back before the Mul.

---
 .../ascend/ascend_backend_optimization.cc     |  3 +-
 .../ir_fusion/fused_batch_norm_fusion.cc      | 25 +++++++++++++++-
 .../ir_fusion/fused_batch_norm_fusion.h       | 15 ++++++++--
 .../ir_fusion/fused_batch_norm_fusion_test.cc | 29 +++++++++++++++++--
 .../fused_batch_norm_fusion_test.py           | 17 ++++++++++-
 5 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index 800c862c53..76301a8b47 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -239,7 +239,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   } else {
     ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   }
   ir_fusion_pm->AddPass(std::make_shared<...>());
   ir_fusion_pm->AddPass(std::make_shared<...>());
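Why this is two passes rather than one: a VectorRef pattern describes a single fixed graph shape, so each placement of the Cast in the update subgraph needs its own pass, and the pass manager simply tries them in registration order. A toy sketch of that idea in plain Python (illustrative only; none of these names come from MindSpore):

    # Toy rewrite loop: each "pass" matches exactly one expression shape, the
    # way one VectorRef pattern matches one graph shape.
    def run_passes(graph, passes):
        changed = True
        while changed:
            changed = False
            for apply_pass in passes:
                rewritten = apply_pass(graph)
                if rewritten != graph:
                    graph, changed = rewritten, True
        return graph

    # Two shapes of the same update: cast after the Mul vs. cast before it.
    fusion0 = lambda g: "fused_bn" if g == ("assign_sub", ("cast", ("mul", "sub"))) else g
    fusion1 = lambda g: "fused_bn" if g == ("assign_sub", ("mul", ("cast", "sub"))) else g

    assert run_passes(("assign_sub", ("mul", ("cast", "sub"))), [fusion0, fusion1]) == "fused_bn"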
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
index 03428e6357..efc9ee7934 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc
@@ -291,7 +291,7 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
   return bn_training_update_outputs[0];
 }
 
-const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
+const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
   std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
   VarPtr index0 = std::make_shared<CondVar>(IsC);
   VarPtr index1 = std::make_shared<CondVar>(IsC);
@@ -313,5 +313,28 @@ const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
   VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
   return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
 }
+
+const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  VarPtr index0 = std::make_shared<CondVar>(IsC);
+  VarPtr index1 = std::make_shared<CondVar>(IsC);
+  VarPtr index2 = std::make_shared<CondVar>(IsC);
+  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
+  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
+  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
+  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
+  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
+  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
+  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
+  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
+  VectorRef cast0 = VectorRef({prim::kPrimCast, sub0});
+  VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
+  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
+  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
+  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
+  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
+}
 } // namespace opt
 } // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
index e4b31ca5f4..f476e96062 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
@@ -61,12 +61,21 @@ class FusedBatchNormFusion : public PatternProcessPass {
   VarPtr batch_norm_var_;
 };
 
-class FusedBatchNormMixPrecisionFusion : public FusedBatchNormFusion {
+class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {
  public:
-  explicit FusedBatchNormMixPrecisionFusion(bool multigraph = true)
+  explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true)
       : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
 
-  ~FusedBatchNormMixPrecisionFusion() override = default;
+  ~FusedBatchNormMixPrecisionFusion0() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion {
+ public:
+  explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true)
+      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
+
+  ~FusedBatchNormMixPrecisionFusion1() override = default;
   const BaseRef DefinePattern() const override;
 };
 } // namespace opt
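The two patterns differ only in where the Cast sits in the moving-statistics update var -= momentum * (Cast(var) - batch_stat). A small NumPy sketch of the two computations (dtypes are assumed for illustration, and the cast-after-Mul layout for pattern 0 is inferred from the pre-existing before_mix_precision graph, which this hunk does not show):

    import numpy as np

    rng = np.random.default_rng(0)
    var = rng.random(64).astype(np.float16)   # moving mean/variance parameter
    stat = rng.random(64).astype(np.float32)  # batch statistic from BatchNorm
    momentum = np.float32(0.1)

    sub = var.astype(np.float32) - stat       # Cast(var) feeding Sub, common to both

    # Pattern 0 shape: Mul in fp32, single Cast on the AssignSub input.
    update0 = var - (sub * momentum).astype(np.float16)

    # Pattern 1 shape (added here): the Sub result is cast back before the Mul.
    update1 = var - sub.astype(np.float16) * momentum.astype(np.float16)

    # The two agree up to fp16 rounding, which is why both shapes occur in
    # mixed-precision graphs and both deserve a fusion pattern.
    print(np.max(np.abs(update0.astype(np.float32) - update1.astype(np.float32))))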
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
index f023446698..597b7b18ff 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
@@ -51,8 +51,8 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) {
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision");
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion0) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision0");
   EXPECT_NE(g, nullptr);
   std::vector<int> shp_x{32, 64, 112, 112};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
@@ -66,7 +66,30 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion>());
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion0>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(kg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion1) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision1");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{32, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  std::vector<int> shp_y{64};
+  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  for (size_t i = 0; i < 6; ++i) {
+    args_spec_list.push_back(y_abstract);
+  }
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion1>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(kg);
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py
index 5b286e358b..4df9942bce 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py
@@ -61,7 +61,7 @@ def test_fused_batch_norm_fusion(tag):
         return output
 
     @fns
-    def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
+    def before_mix_precision0(input0, input1, input2, input3, input4, var0, var1):
         batch_norm = BatchNorm(input0, input1, input2, input3, input4)
         sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
         sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
@@ -75,6 +75,21 @@ def test_fused_batch_norm_fusion(tag):
         output = tuple_getitem(outputs, 0)
         return output
 
+    @fns
+    def before_mix_precision1(input0, input1, input2, input3, input4, var0, var1):
+        batch_norm = BatchNorm(input0, input1, input2, input3, input4)
+        sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
+        sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
+        mul0 = Mul(Cast(sub0, mstype.float32), constant0)
+        mul1 = Mul(Cast(sub1, mstype.float32), constant1)
+        assign_sub0 = AssignSub(var0, mul0)
+        assign_sub1 = AssignSub(var1, mul1)
+        depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
+        depend1 = depend(depend0, assign_sub1)
+        outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
+        output = tuple_getitem(outputs, 0)
+        return output
+
     @fns
     def after(input0, input1, input2, input3, input4, var0, var1):
         bn_training_reduce = BNTrainingReduce(input0)
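The expected after graph (only its opening line is visible above; the remainder of the file is truncated here) rewrites the matched subgraph into BNTrainingReduce followed by BNTrainingUpdate. A rough NumPy sketch of what that pair computes, assuming the commonly documented Ascend semantics for these ops; this is a reading aid, not code from the patch:

    import numpy as np

    def bn_training_reduce(x):
        # Per-channel sum and sum of squares over N, H, W (NCHW layout assumed).
        return x.sum(axis=(0, 2, 3)), np.square(x).sum(axis=(0, 2, 3))

    def bn_training_update(x, total, square, scale, offset, mean, variance, factor, eps=1e-5):
        n = x.shape[0] * x.shape[2] * x.shape[3]
        batch_mean = total / n
        batch_var = square / n - np.square(batch_mean)
        c = lambda v: v.reshape(1, -1, 1, 1)
        y = c(scale) * (x - c(batch_mean)) / np.sqrt(c(batch_var) + eps) + c(offset)
        # Moving-statistics updates folded in by the fusion (the AssignSub subgraphs):
        new_mean = mean - factor * (mean - batch_mean)
        new_variance = variance - factor * (variance - batch_var)
        return y, new_mean, new_variance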