From b29d260be3503105ea3861a860bb6bf787eb3c6c Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Fri, 19 Jun 2020 02:57:03 +0000 Subject: [PATCH] reuse AddN primitive in opt as AddN will replicated by program_specialize --- mindspore/ccsrc/optimizer/irpass/merge_addn.h | 19 ++++++------------- .../gtest_input/optimizer/opt_test.py | 6 ++---- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/optimizer/irpass/merge_addn.h index c96c35407f..e1e4b8878b 100644 --- a/mindspore/ccsrc/optimizer/irpass/merge_addn.h +++ b/mindspore/ccsrc/optimizer/irpass/merge_addn.h @@ -35,9 +35,6 @@ namespace irpass { // {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} class MergeAddN : public AnfVisitor { public: - MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {} - ~MergeAddN() override = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { Reset(); optimizer_ = optimizer; @@ -47,15 +44,15 @@ class MergeAddN : public AnfVisitor { return nullptr; } - auto fg = node->func_graph(); - // {PrimAddNClass} - auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)}); + auto cnode = node->cast(); + auto addn = NewValueNode(GetValueNode(cnode->input(0))); // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); + auto fg = node->func_graph(); auto make_node = fg->NewCNode(args_); - return fg->NewCNode({addn_node, make_node}); + return fg->NewCNode({addn, make_node}); } void Visit(const CNodePtr &cnode) override { @@ -127,7 +124,6 @@ class MergeAddN : public AnfVisitor { } private: - ValuePtr PrimAddN_; OptimizerPtr optimizer_{nullptr}; std::vector Xs_{}, Ys_{}, args_{}; bool is_inner_{false}, is_outer_{false}, is_match_{false}; @@ -136,9 +132,6 @@ class MergeAddN : public AnfVisitor { // {PrimAddN, {kPrimMakeTuple, Xs}} class AddNZeroFilter : public AnfVisitor { public: - AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {} - ~AddNZeroFilter() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); @@ -161,8 +154,9 @@ class AddNZeroFilter : public AnfVisitor { return nullptr; } + auto cnode = node->cast(); + auto addn = NewValueNode(GetValueNode(cnode->input(0))); auto fg = node->func_graph(); - auto addn = fg->NewCNode({NewValueNode(PrimAddN_)}); auto make_tuple = fg->NewCNode(filtered_Xs_); return fg->NewCNode({addn, make_tuple}); } @@ -193,7 +187,6 @@ class AddNZeroFilter : public AnfVisitor { } private: - ValuePtr PrimAddN_; std::vector filtered_Xs_{}, Xs_{}; bool has_zero_like_{false}; }; diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index 22e2535819..16c557adbe 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -875,7 +875,6 @@ def test_merge_addn(tag): """ test_merge_addn """ fns = FnDict() addn = P.AddN() - AddN = P.AddN @fns def before(x, y, z, a): @@ -883,7 +882,7 @@ def test_merge_addn(tag): @fns def after(x, y, z, a): - return AddN()((a, x, y, z)) + return addn((a, x, y, z)) return fns[tag] @@ -892,7 +891,6 @@ def test_addn_zero(tag): """ test_addn_zero """ fns = FnDict() addn = P.AddN() - AddN = P.AddN zero_tensor = Primitive('ZerosLike') @fns @@ -901,7 +899,7 @@ def test_addn_zero(tag): @fns def after(x, y, z, a): - return AddN()((a, z)) + return addn((a, z)) @fns def before_2(x, y, z, a):