diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index bdefc9bf7c..4a6b5b1e4a 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -104,6 +104,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond1>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond2>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond3>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
index 7265f1b60d..5f0b869644 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
@@ -116,9 +116,116 @@ const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const A
   return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
 }
 
+const BaseRef LambNextMVRuleCond1::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+
+  auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
+  auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
+  auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
+  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
+  auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
+  auto add0 = VectorRef({add0_var_, mul0, mul1});
+  auto add1 = VectorRef({add1_var_, mul2, mul3});
+
+  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
+  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
+
+  auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
+  auto sqrt0 = VectorRef({prim_rsqrt, add2});
+  auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
+
+  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+}
+
+BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  // Two patterns share: real_div0, real_div1, add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1});
+  VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
+  return real_div4;
+}
+
+const BaseRef LambNextMVRuleCond2::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+
+  auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
+  auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
+  auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
+  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
+  auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
+  auto add0 = VectorRef({add0_var_, mul0, mul1});
+  auto add1 = VectorRef({add1_var_, mul2, mul3});
+
+  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
+  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
+
+  auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
+  auto sqrt0 = VectorRef({prim_rsqrt, add2});
+  auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
+
+  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+}
+
+BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  // Two patterns share: real_div0, real_div1, add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
+  return real_div4;
+}
+
+const BaseRef LambNextMVRuleCond3::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+
+  auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
+  auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
+  auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
+  auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_});
+  auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
+  auto add0 = VectorRef({add0_var_, mul0, mul1});
+  auto add1 = VectorRef({add1_var_, mul2, mul3});
+
+  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
+  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
+
+  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+  auto sqrt0 = VectorRef({prim_rsqrt, add2});
+  auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
+
+  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+}
+
+BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  // Two patterns share: real_div0, real_div1, add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
+  return real_div4;
+}
+
 const BaseRef LambNextMVRuleCond4::DefinePattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
 
   auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
   auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
@@ -140,13 +247,9 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const {
 
 BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
   const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_sqrt);
   const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_real_div);
   VarPtr Xs = std::make_shared<SeqVar>();
   VarPtr Ys = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  MS_EXCEPTION_IF_NULL(Ys);
   // Two patterns share: real_div0, real_div1, add2_y_
   VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
   VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h
index f28d837dcf..0089c33f87 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h
@@ -87,6 +87,33 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass {
   VarPtr real_div2_var_;
 };
 
+class LambNextMVRuleCond1 : public LambNextMVRule {
+ public:
+  explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {}
+
+  ~LambNextMVRuleCond1() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
+
+class LambNextMVRuleCond2 : public LambNextMVRule {
+ public:
+  explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {}
+
+  ~LambNextMVRuleCond2() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
+
+class LambNextMVRuleCond3 : public LambNextMVRule {
+ public:
+  explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {}
+
+  ~LambNextMVRuleCond3() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
+
 class LambNextMVRuleCond4 : public LambNextMVRule {
  public:
   explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
index e5b4c6bc32..6ea622d030 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
@@ -244,5 +244,125 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }
 
+
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_fusion) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "before");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond1>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_unmatched) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "un_match");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond1>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
+}
+
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_fusion) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "before");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond2>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_unmatched) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "un_match");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond2>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
+}
+
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_fusion) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "before");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond3>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_unmatched) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "un_match");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond3>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+
+  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py
index c93c00cd72..5660771723 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py
@@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
 LambNextMV = Primitive('LambNextMV')
 
-
 class FnDict:
     def __init__(self):
         self.fnDict = {}
@@ -35,7 +34,6 @@ class FnDict:
     def __getitem__(self, name):
         return self.fnDict[name]
 
-
 def test_lamb_next_mv_rule_cond4(tag):
     fns = FnDict()
 
@@ -170,3 +168,192 @@ def test_lamb_next_mv_rule_cond4(tag):
         return output
 
     return fns[tag]
+
+def test_lamb_next_mv_rule_cond1(tag):
+    fns = FnDict()
+
+    @fns
+    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul0 = Mul(constant_mul0_x, input4)
+        mul1 = Mul(constant_mul1_sub, input3)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(constant_mul2_x, input1)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt0 = Rsqrt(add2)
+        sqrt1 = Sqrt(real_div1)
+        add4 = Add(constant_add2_y, sqrt1)
+        real_div0 = RealDiv(add0, input5)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        mul4 = Mul(constant_mul4_x, input6)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, real_div4)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
+                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
+                                  constant_mul4_x, constant_add2_y)
+        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
+                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
+        output = tuple_getitem(outputs, 0)
+        return make_tuple(output)
+
+    @fns
+    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul0 = Mul(constant_mul0_x, input4)
+        mul1 = Mul(constant_mul1_sub, input3)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(constant_mul2_x, input1)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt0 = Rsqrt(add2)
+        sqrt1 = Sqrt(real_div1)
+        # un match
+        add4 = Add(sqrt1, constant_add2_y)
+        real_div0 = RealDiv(add0, input5)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        mul4 = Mul(constant_mul4_x, input6)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, real_div4)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    return fns[tag]
+
+def test_lamb_next_mv_rule_cond2(tag):
+    fns = FnDict()
+
+    @fns
+    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul0 = Mul(input4, constant_mul0_x)
+        mul1 = Mul(input3, constant_mul1_sub)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt0 = Rsqrt(add2)
+        sqrt1 = Sqrt(real_div1)
+        add4 = Add(sqrt1, constant_add2_y)
+        real_div0 = RealDiv(add0, input5)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        mul4 = Mul(input6, constant_mul4_x)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, real_div4)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
+                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
+                                  constant_mul4_x, constant_add2_y)
+        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
+                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
+        output = tuple_getitem(outputs, 0)
+        return make_tuple(output)
+
+    @fns
+    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul0 = Mul(input4, constant_mul0_x)
+        mul1 = Mul(input3, constant_mul1_sub)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt0 = Rsqrt(add2)
+        sqrt1 = Sqrt(real_div1)
+        # un match
+        add4 = Add(constant_add2_y, sqrt1)
+        real_div0 = RealDiv(add0, input5)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        mul4 = Mul(input6, constant_mul4_x)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, real_div4)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    return fns[tag]
+
+def test_lamb_next_mv_rule_cond3(tag):
+    fns = FnDict()
+
+    @fns
+    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul0 = Mul(input4, constant_mul0_x)
+        mul1 = Mul(input3, constant_mul1_sub)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(input0, constant_mul3_sub1)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(real_div1, constant_add2_y)
+        sqrt0 = Rsqrt(add2)
+        sqrt1 = Sqrt(real_div1)
+        add4 = Add(sqrt1, constant_add2_y)
+        real_div0 = RealDiv(add0, input5)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        mul4 = Mul(input6, constant_mul4_x)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, real_div4)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
+                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
+                                  constant_mul4_x, constant_add2_y)
+        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
+                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
+        output = tuple_getitem(outputs, 0)
+        return make_tuple(output)
+
+    @fns
+    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul0 = Mul(input4, constant_mul0_x)
+        mul1 = Mul(input3, constant_mul1_sub)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(input0, constant_mul3_sub1)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(real_div1, constant_add2_y)
+        sqrt0 = Rsqrt(add2)
+        sqrt1 = Sqrt(real_div1)
+        # un match
+        add4 = Add(constant_add2_y, sqrt1)
+        real_div0 = RealDiv(add0, input5)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        mul4 = Mul(input6, constant_mul4_x)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, real_div4)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    return fns[tag]
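
Reviewer note (illustrative, not part of the patch): the `before` graphs above all encode the same arithmetic; the Cond1/Cond2/Cond3 patterns differ from each other and from the existing Cond4 rule only in the operand order of the commutative Mul/Add nodes, which is presumably why each ordering needs its own pattern. The NumPy sketch below summarizes the computation that the fused LambNextMV node stands in for; the function name and comments are chosen here for illustration and mirror the variable names used in the test graphs.

import numpy as np

def lamb_next_mv_reference(input0, input1, input2, input3, input4, input5, input6,
                           mul0_x, mul1_sub, mul2_x, mul3_sub1, mul4_x, add2_y):
    # add0/add1: the two weighted sums rebuilt from the inputs
    add0 = mul0_x * input4 + mul1_sub * input3
    add1 = mul2_x * input1 + mul3_sub1 * input0
    # real_div0/real_div1: the division sub-graphs shared by both patterns
    real_div0 = add0 / input5
    real_div1 = add1 / input2
    # Rsqrt branch (DefinePattern): real_div0 / sqrt(real_div1 + add2_y)
    real_div2 = real_div0 / np.sqrt(real_div1 + add2_y)
    add3 = mul4_x * input6 + real_div2
    # Sqrt branch (DefineAnotherPattern): real_div0 / (sqrt(real_div1) + add2_y)
    real_div4 = real_div0 / (np.sqrt(real_div1) + add2_y)
    # output order matches make_tuple(add3, add0, add1, real_div4) in the test graphs
    return add3, add0, add1, real_div4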