diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index f6feb0440f..52f61f7027 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -99,11 +99,11 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
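LambNextMVRule and LambNextMVWithDecayRule become abstract bases further down in this patch (their pattern definitions turn pure virtual), so the optional IR-fusion list swaps in the concrete Cond4 variants; the class names in the removed and added AddPass lines above are inferred from those renames. A sketch of how this list is driven, assuming the PassManager API used in the hunk header (the name argument is illustrative):

    // Passes run in registration order, so the Cond4 variants keep the
    // slot the old rules occupied relative to their neighbours.
    auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
    AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
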
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
index f0a55c6ff4..7265f1b60d 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
@@ -27,82 +27,12 @@
 namespace mindspore {
 namespace opt {
-namespace {
-std::tuple<CNodePtr, CNodePtr, AnfNodePtr> GetSharedNodesByPattern(const AnfNodePtr &node) {
-  auto add3_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kAddInputNum);
-  MS_EXCEPTION_IF_NULL(add3_cnode);
-  auto real_div2_cnode = CheckAnfNodeIfCNodeAndInputSize(add3_cnode->input(1), kMulInputNum);
-  MS_EXCEPTION_IF_NULL(real_div2_cnode);
-  auto real_div0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(1), kRealDivInputNum);
-  MS_EXCEPTION_IF_NULL(real_div0_cnode);
-  auto sqrt0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(2), kSqrtInputNum);
-  MS_EXCEPTION_IF_NULL(sqrt0_cnode);
-  auto add2_cnode = CheckAnfNodeIfCNodeAndInputSize(sqrt0_cnode->input(1), kAddInputNum);
-  MS_EXCEPTION_IF_NULL(add2_cnode);
-  auto real_div1_cnode = CheckAnfNodeIfCNodeAndInputSize(add2_cnode->input(1), kRealDivInputNum);
-  auto constant_add2_y = add2_cnode->input(2);
-
-  return std::make_tuple(real_div0_cnode, real_div1_cnode, constant_add2_y);
-}
-
-bool MatchRealDiv4(const AnfNodePtr &real_div4, const AnfNodePtr &real_div1, const AnfNodePtr &constant_add2_y) {
-  if (real_div4 == nullptr || !real_div4->isa<CNode>()) {
-    return false;
-  }
-  auto real_div4_cnode = real_div4->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div4_cnode);
-  if (AnfAlgo::GetCNodeName(real_div4_cnode) != kRealDivOpName || real_div4_cnode->inputs().size() < kRealDivInputNum) {
-    return false;
-  }
-
-  CNodePtr add4_cnode = nullptr;
-  if (!CheckIfCNodeAndInputSize(real_div4_cnode->input(2), kAddInputNum, &add4_cnode) ||
-      AnfAlgo::GetCNodeName(add4_cnode) != prim::kPrimTensorAdd->name()) {
-    return false;
-  }
-  CNodePtr sqrt1_cnode = nullptr;
-  if (!CheckIfCNodeAndInputSize(add4_cnode->input(1), kSqrtInputNum, &sqrt1_cnode) ||
-      AnfAlgo::GetCNodeName(sqrt1_cnode) != kSqrtOpName) {
-    return false;
-  }
-
-  MS_EXCEPTION_IF_NULL(add4_cnode->input(2));
-  MS_EXCEPTION_IF_NULL(constant_add2_y);
-  return sqrt1_cnode->input(1) == real_div1 && *(add4_cnode->input(2)) == *constant_add2_y;
-}
-}  // namespace
-
-const BaseRef LambNextMVRule::DefinePattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_deal_div);
-
-  auto mul0 = VectorRef({prim::kPrimMul, input_varptr_[7], input_varptr_[4]});
-  auto mul1 = VectorRef({prim::kPrimMul, input_varptr_[8], input_varptr_[3]});
-  auto mul2 = VectorRef({prim::kPrimMul, input_varptr_[9], input_varptr_[1]});
-  auto mul3 = VectorRef({prim::kPrimMul, input_varptr_[10], input_varptr_[0]});
-  auto mul4 = VectorRef({prim::kPrimMul, input_varptr_[11], input_varptr_[6]});
-  auto add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1});
-  auto add1 = VectorRef({prim::kPrimTensorAdd, mul2, mul3});
-
-  auto real_div0 = VectorRef({prim_deal_div, add0, input_varptr_[5]});
-  auto real_div1 = VectorRef({prim_deal_div, add1, input_varptr_[2]});
-
-  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, input_varptr_[12]});
-  auto sqrt0 = VectorRef({prim_rsqrt, add2});
-  auto real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-
-  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
-}
-
-bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
                                    std::vector<AnfNodePtr> *old_pattern_outputs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  CNodePtr real_div0 = nullptr;
-  CNodePtr real_div1 = nullptr;
-  AnfNodePtr constant_add2_y = nullptr;
-  std::tie(real_div0, real_div1, constant_add2_y) = GetSharedNodesByPattern(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
+  auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);
 
   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -112,19 +42,17 @@ bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNode
   }
   AnfNodeIndexSet real_div0_outputs = users[real_div0];
   auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
-                           [&node, &real_div1, &constant_add2_y](const std::pair<AnfNodePtr, int> &node_index) {
-                             return node_index.first != node && node_index.second == 1 &&
-                                    MatchRealDiv4(node_index.first, real_div1, constant_add2_y);
+                           [&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
+                             return node_index.first != real_div2 && node_index.second == 1 &&
+                                    MatchAnotherPattern(node_index.first, equiv);
                            });
   if (iter == real_div0_outputs.end()) {
     return false;
   }
 
-  auto add0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div0->input(1), kAddInputNum);
-  auto add1_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div1->input(1), kAddInputNum);
   (*old_pattern_outputs).push_back(node);
-  (*old_pattern_outputs).push_back(add0_cnode);
-  (*old_pattern_outputs).push_back(add1_cnode);
+  (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
+  (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
   (*old_pattern_outputs).push_back(iter->first);
 
   return true;
@@ -136,8 +64,19 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(func_graph);
   auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
   std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
-  (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(lamb_next_mv_rule_inputs),
-                       [&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
   auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
   MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
@@ -162,14 +101,60 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   return lamb_next_mv_rule_outputs[0];
 }
 
+bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
+  return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
+         IsSameNode(equiv1, equiv2, add2_y_);
+}
+
 const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
   std::vector<AnfNodePtr> old_pattern_outputs;
-  if (!IsRuleMatched(func_graph, node, &old_pattern_outputs)) {
+  if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
     return nullptr;
   }
 
   return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
 }
+
+const BaseRef LambNextMVRuleCond4::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+
+  auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
+  auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
+  auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
+  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
+  auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
+  auto add0 = VectorRef({add0_var_, mul0, mul1});
+  auto add1 = VectorRef({add1_var_, mul2, mul3});
+
+  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
+  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
+
+  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+  auto sqrt0 = VectorRef({prim_rsqrt, add2});
+  auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
+
+  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+}
+
+BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_real_div);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  // Two patterns share: real_div0, real_div1, add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
+  return real_div4;
+}
 }  // namespace opt
 }  // namespace mindspore
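The rewrite above replaces the anonymous-namespace traversal helpers (GetSharedNodesByPattern, MatchRealDiv4) with a second declarative pattern. The key idiom is that a Var constructed around a Primitive only matches CNodes of that primitive, and after a successful match the bound node can be read back from the equiv map. A minimal sketch of that idiom using the same Var/VectorRef/GetAnfNodeByVar API (the pattern itself is illustrative, not one from this patch):

    // A var that matches only RealDiv CNodes, plus two wildcard inputs.
    VarPtr div_var = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
    VarPtr x = std::make_shared<Var>();
    VarPtr y = std::make_shared<Var>();
    // Pattern: TensorAdd(RealDiv(x, y), x)
    VectorRef div = VectorRef({div_var, x, y});
    VectorRef pattern = VectorRef({prim::kPrimTensorAdd, div, x});
    // After the engine matches `pattern` and fills `equiv`, the RealDiv
    // CNode is recovered without re-walking inputs:
    AnfNodePtr matched_div = GetAnfNodeByVar(equiv, div_var);
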
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h
index 33fb41662d..f28d837dcf 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h
@@ -29,23 +29,71 @@
 namespace mindspore {
 namespace opt {
-class LambNextMVRule : public PatternProcessPass {
+class LambNextMVRule : public MultipleOutputPatternProcessPass {
  public:
-  explicit LambNextMVRule(bool multigraph = true) : PatternProcessPass("lamb_next_mv_rule", multigraph) {
-    for (size_t i = 0; i < kLambNextMVRuleInputNum - 1; ++i) {
-      input_varptr_.push_back(std::make_shared<Var>());
-    }
+  explicit LambNextMVRule(const std::string &name = "", bool multigraph = true)
+      : MultipleOutputPatternProcessPass(name, multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    input2_ = std::make_shared<Var>();
+    input3_ = std::make_shared<Var>();
+    input4_ = std::make_shared<Var>();
+    input5_ = std::make_shared<Var>();
+    input6_ = std::make_shared<Var>();
+    mul0_x_ = std::make_shared<Var>();
+    mul1_sub_ = std::make_shared<Var>();
+    mul2_x_ = std::make_shared<Var>();
+    mul3_sub1_ = std::make_shared<Var>();
+    mul4_x_ = std::make_shared<Var>();
+    add2_y_ = std::make_shared<Var>();
+    real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
   }
   ~LambNextMVRule() override = default;
-  const BaseRef DefinePattern() const override;
+  const BaseRef DefinePattern() const override = 0;
+  BaseRef DefineAnotherPattern() const override = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;
 
- private:
-  std::vector<VarPtr> input_varptr_;
-  bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+ protected:
+  bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
                      std::vector<AnfNodePtr> *old_pattern_outputs) const;
   AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &old_pattern_outputs,
                                   const EquivPtr &equiv) const;
+
+  VarPtr input0_;
+  VarPtr input1_;
+  VarPtr input2_;
+  VarPtr input3_;
+  VarPtr input4_;
+  VarPtr input5_;
+  VarPtr input6_;
+  VarPtr mul0_x_;
+  VarPtr mul1_sub_;
+  VarPtr mul2_x_;
+  VarPtr mul3_sub1_;
+  VarPtr mul4_x_;
+  VarPtr add2_y_;
+  // nodes which two patterns share, and add2_y_ also.
+  VarPtr real_div0_var_;
+  VarPtr real_div1_var_;
+  // part of output nodes
+  VarPtr add0_var_;
+  VarPtr add1_var_;
+  // other node
+  VarPtr real_div2_var_;
+};
+
+class LambNextMVRuleCond4 : public LambNextMVRule {
+ public:
+  explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
+
+  ~LambNextMVRuleCond4() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
index dc723f3052..e0389309a1 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
@@ -79,63 +79,6 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
   return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
 }
 
-const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
-  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_sqrt);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_deal_div);
-  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
-  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
-  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
-  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
-  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
-  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
-  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
-  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
-  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
-  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
-  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
-  return add5;
-}
-
-const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
-  VarPtr Xs = std::make_shared<SeqVar>();
-  VarPtr Ys = std::make_shared<SeqVar>();
-  VarPtr Zs = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  MS_EXCEPTION_IF_NULL(Ys);
-  MS_EXCEPTION_IF_NULL(Zs);
-  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
-  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
-  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
-  VectorRef mul4 = VectorRef({mul4_var_, Zs});
-
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
-  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
-  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
-  return add3;
-}
-
-bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(equiv);
-  VarPtr fg = std::make_shared<Var>("RootG");
-  auto empty_equiv = std::make_shared<Equiv>();
-  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
-  EquivPtr another_equiv =
-      child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
-                                  *child_primitive_vars_, empty_equiv);
-  if (another_equiv != nullptr && !another_equiv->empty()) {
-    return IsShareNodes(equiv, another_equiv);
-  }
-  return false;
-}
-
 bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
   return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
          IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
@@ -164,7 +107,7 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
   return nullptr;
 }
 
-const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -205,7 +148,7 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
   return add5;
 }
 
-const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -246,7 +189,7 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
   return add5;
 }
 
-const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -286,5 +229,47 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
   VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
   return add5;
 }
+
+BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
+  return add5;
+}
 }  // namespace opt
 }  // namespace mindspore
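Each rule in this family describes one fused kernel with two consumers: DefinePattern() anchors the pass on one output (add5) while DefineAnotherPattern() describes the sibling subgraph rooted at the other output (add3). Since the two patterns are matched independently, the fusion must verify that they overlap on the intended nodes, which is exactly what IsShareNodes() encodes. A hypothetical acceptance helper that mirrors the MatchAnotherPattern/IsShareNodes flow (simplified, not part of this patch):

    // Fire the fusion only if the sibling pattern matched and both matches
    // bound the shared vars to the very same graph nodes.
    bool AcceptPair(const EquivPtr &main_equiv, const EquivPtr &other_equiv) const {
      if (other_equiv == nullptr || other_equiv->empty()) {
        return false;  // sibling subgraph not found near this candidate
      }
      return IsShareNodes(main_equiv, other_equiv);
    }
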
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
index 4d52451a07..5d61975197 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
@@ -24,15 +24,10 @@
 namespace mindspore {
 namespace opt {
-class LambNextMVWithDecayRule : public PatternProcessPass {
+class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
  public:
-  explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
-                                   bool multigraph = true)
-      : PatternProcessPass(name, multigraph),
-        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
-                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
-                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
-        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
+  explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true)
+      : MultipleOutputPatternProcessPass(name, multigraph) {
     for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
       input_vars_.push_back(std::make_shared<Var>());
     }
@@ -48,21 +43,16 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
   }
   ~LambNextMVWithDecayRule() override = default;
-  const BaseRef DefinePattern() const override;
-  virtual const BaseRef DefineAnotherPattern() const;
+  const BaseRef DefinePattern() const override = 0;
+  BaseRef DefineAnotherPattern() const override = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;
 
  protected:
-  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
-  // check two patterns whether share the same nodes or not
-  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
-
   AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
                                           const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
   AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
                                            const AnfNodePtr &add5, const EquivPtr &equiv) const;
-  PatternEngine child_pattern_engine_;
-  PrimitiveVarMapPtr child_primitive_vars_;
   std::vector<VarPtr> input_vars_;
   std::vector<VarPtr> constant_mul_input_vars_;
   // nodes which two patterns share
@@ -82,7 +72,7 @@ class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
 
   ~LambNextMVWithDecayRuleCond1() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };
 
 class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
@@ -92,7 +82,7 @@ class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
 
   ~LambNextMVWithDecayRuleCond2() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };
 
 class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
@@ -102,7 +92,17 @@ class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
 
   ~LambNextMVWithDecayRuleCond3() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
+
+class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule {
+ public:
+  explicit LambNextMVWithDecayRuleCond4(bool multigraph = true)
+      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {}
+
+  ~LambNextMVWithDecayRuleCond4() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.cc b/mindspore/ccsrc/pre_activate/common/optimizer.cc
index 2711d87721..fa51a0bd8c 100644
--- a/mindspore/ccsrc/pre_activate/common/optimizer.cc
+++ b/mindspore/ccsrc/pre_activate/common/optimizer.cc
@@ -62,6 +62,21 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
   return nullptr;
 }
 
+bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  VarPtr fg = std::make_shared<Var>("RootG");
+  auto empty_equiv = std::make_shared<Equiv>();
+  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
+  EquivPtr another_equiv =
+      child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
+                                  *child_primitive_vars_, empty_equiv);
+  if (another_equiv != nullptr && !another_equiv->empty()) {
+    return IsShareNodes(equiv, another_equiv);
+  }
+  return false;
+}
+
 void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
   if (pass_manager != nullptr) {
     pass_managers_.push_back(pass_manager);
diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.h b/mindspore/ccsrc/pre_activate/common/optimizer.h
index cec23ae178..1f9961df6b 100644
--- a/mindspore/ccsrc/pre_activate/common/optimizer.h
+++ b/mindspore/ccsrc/pre_activate/common/optimizer.h
@@ -51,6 +51,25 @@ class PatternProcessPass : public NodePass {
   PrimitiveVarMapPtr primitive_vars_;
 };
 
+class MultipleOutputPatternProcessPass : public PatternProcessPass {
+ public:
+  explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
+      : PatternProcessPass(name, multigraph),
+        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
+        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
+  ~MultipleOutputPatternProcessPass() override = default;
+  virtual BaseRef DefineAnotherPattern() const = 0;
+  // check two patterns whether share the same nodes or not
+  virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;
+
+ protected:
+  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
+  PatternEngine child_pattern_engine_;
+  PrimitiveVarMapPtr child_primitive_vars_;
+};
+
 class GraphOptimizer {
  public:
   explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
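The new MultipleOutputPatternProcessPass factors the double-match machinery out of the decay rule so any multi-output fusion can reuse it: subclasses keep PatternProcessPass's DefinePattern()/Process() contract and add the second pattern plus the share check. A minimal hypothetical subclass (FooFusion and its var are illustrative, not part of this patch):

    class FooFusion : public MultipleOutputPatternProcessPass {
     public:
      explicit FooFusion(bool multigraph = true)
          : MultipleOutputPatternProcessPass("foo_fusion", multigraph), x_(std::make_shared<Var>()) {}
      ~FooFusion() override = default;
      const BaseRef DefinePattern() const override;   // first output's subgraph
      BaseRef DefineAnotherPattern() const override;  // sibling output's subgraph
      bool IsShareNodes(const EquivPtr &e1, const EquivPtr &e2) const override {
        return IsSameNode(e1, e2, x_);  // both matches must bind x_ identically
      }
      const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

     private:
      VarPtr x_;  // anchor var shared by both patterns
    };
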
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
index 797ca08b76..e5b4c6bc32 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
@@ -30,7 +30,7 @@ class TestHWLambNextMVRule : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };
 
-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_matched) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -54,7 +54,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -65,15 +65,15 @@
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
 
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div4) {
   /*
    * def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -97,7 +97,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div4");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div4");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -109,14 +109,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }
 
-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div2) {
   /*
    * def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -140,7 +140,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div2");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div2");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -152,14 +152,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }
 
-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div0) {
   /*
    * def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -183,7 +183,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div0");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div0");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -195,14 +195,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }
 
-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
   /*
    * def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -226,7 +226,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div1");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div1");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -238,7 +238,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
index 53ebbd6f2f..36f0321511 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
@@ -30,7 +30,7 @@ class TestHWLambNextMVWithDecayRule : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };
 
-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_matched) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -55,7 +55,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -66,15 +66,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
 
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_add3) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -99,7 +99,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_add3");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_add3");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -111,15 +111,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_mul4) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -144,7 +144,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_mul4");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_mul4");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -156,15 +156,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div0) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -189,7 +189,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div0");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div0");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -201,15 +201,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div1) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
@@ -234,7 +234,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div1");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div1");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
@@ -246,11 +246,11 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 
-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py
index 21b236d694..c93c00cd72 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py
@@ -36,7 +36,7 @@ class FnDict:
         return self.fnDict[name]
 
 
-def test_lamb_next_mv_rule(tag):
+def test_lamb_next_mv_rule_cond4(tag):
     fns = FnDict()
 
     @fns
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py
index d2931cce36..87dba9b3f7 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py
@@ -34,7 +34,7 @@ class FnDict:
     def __getitem__(self, name):
         return self.fnDict[name]
 
-def test_lamb_next_mv_with_decay_rule(tag):
+def test_lamb_next_mv_with_decay_rule_cond4(tag):
    fns = FnDict()

    @fns
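The Python graph builders above are resolved by function name at test time, which is why the rename has to land in the C++ tests and the Python inputs together. For reference, the lookup as it appears in the tests in this patch (the first argument selects the renamed Python function, the second the tagged inner builder registered via FnDict):

    FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before");
    FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "after");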