!2333 remove addnclass in opt

Merge pull request !2333 from xychow/remove-addnclass-in-opt
pull/2333/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2711a628fb

@ -35,9 +35,6 @@ namespace irpass {
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} // {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
class MergeAddN : public AnfVisitor { class MergeAddN : public AnfVisitor {
public: public:
MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
~MergeAddN() override = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset(); Reset();
optimizer_ = optimizer; optimizer_ = optimizer;
@ -47,15 +44,15 @@ class MergeAddN : public AnfVisitor {
return nullptr; return nullptr;
} }
auto fg = node->func_graph(); auto cnode = node->cast<CNodePtr>();
// {PrimAddNClass} auto addn = NewValueNode(GetValueNode(cnode->input(0)));
auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)});
// {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
(void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
auto fg = node->func_graph();
auto make_node = fg->NewCNode(args_); auto make_node = fg->NewCNode(args_);
return fg->NewCNode({addn_node, make_node}); return fg->NewCNode({addn, make_node});
} }
void Visit(const CNodePtr &cnode) override { void Visit(const CNodePtr &cnode) override {
@ -127,7 +124,6 @@ class MergeAddN : public AnfVisitor {
} }
private: private:
ValuePtr PrimAddN_;
OptimizerPtr optimizer_{nullptr}; OptimizerPtr optimizer_{nullptr};
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}; std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
bool is_inner_{false}, is_outer_{false}, is_match_{false}; bool is_inner_{false}, is_outer_{false}, is_match_{false};
@ -136,9 +132,6 @@ class MergeAddN : public AnfVisitor {
// {PrimAddN, {kPrimMakeTuple, Xs}} // {PrimAddN, {kPrimMakeTuple, Xs}}
class AddNZeroFilter : public AnfVisitor { class AddNZeroFilter : public AnfVisitor {
public: public:
AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
~AddNZeroFilter() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
@ -161,8 +154,9 @@ class AddNZeroFilter : public AnfVisitor {
return nullptr; return nullptr;
} }
auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
auto fg = node->func_graph(); auto fg = node->func_graph();
auto addn = fg->NewCNode({NewValueNode(PrimAddN_)});
auto make_tuple = fg->NewCNode(filtered_Xs_); auto make_tuple = fg->NewCNode(filtered_Xs_);
return fg->NewCNode({addn, make_tuple}); return fg->NewCNode({addn, make_tuple});
} }
@ -193,7 +187,6 @@ class AddNZeroFilter : public AnfVisitor {
} }
private: private:
ValuePtr PrimAddN_;
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false}; bool has_zero_like_{false};
}; };

@ -875,7 +875,6 @@ def test_merge_addn(tag):
""" test_merge_addn """ """ test_merge_addn """
fns = FnDict() fns = FnDict()
addn = P.AddN() addn = P.AddN()
AddN = P.AddN
@fns @fns
def before(x, y, z, a): def before(x, y, z, a):
@ -883,7 +882,7 @@ def test_merge_addn(tag):
@fns @fns
def after(x, y, z, a): def after(x, y, z, a):
return AddN()((a, x, y, z)) return addn((a, x, y, z))
return fns[tag] return fns[tag]
@ -892,7 +891,6 @@ def test_addn_zero(tag):
""" test_addn_zero """ """ test_addn_zero """
fns = FnDict() fns = FnDict()
addn = P.AddN() addn = P.AddN()
AddN = P.AddN
zero_tensor = Primitive('ZerosLike') zero_tensor = Primitive('ZerosLike')
@fns @fns
@ -901,7 +899,7 @@ def test_addn_zero(tag):
@fns @fns
def after(x, y, z, a): def after(x, y, z, a):
return AddN()((a, z)) return addn((a, z))
@fns @fns
def before_2(x, y, z, a): def before_2(x, y, z, a):

Loading…
Cancel
Save