!2333 remove addnclass in opt

Merge pull request !2333 from xychow/remove-addnclass-in-opt
pull/2333/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2711a628fb

@ -35,9 +35,6 @@ namespace irpass {
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
class MergeAddN : public AnfVisitor {
public:
MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
~MergeAddN() override = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
optimizer_ = optimizer;
@ -47,15 +44,15 @@ class MergeAddN : public AnfVisitor {
return nullptr;
}
auto fg = node->func_graph();
// {PrimAddNClass}
auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)});
auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
// {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
(void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
auto fg = node->func_graph();
auto make_node = fg->NewCNode(args_);
return fg->NewCNode({addn_node, make_node});
return fg->NewCNode({addn, make_node});
}
void Visit(const CNodePtr &cnode) override {
@ -127,7 +124,6 @@ class MergeAddN : public AnfVisitor {
}
private:
ValuePtr PrimAddN_;
OptimizerPtr optimizer_{nullptr};
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
bool is_inner_{false}, is_outer_{false}, is_match_{false};
@ -136,9 +132,6 @@ class MergeAddN : public AnfVisitor {
// {PrimAddN, {kPrimMakeTuple, Xs}}
class AddNZeroFilter : public AnfVisitor {
public:
AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
~AddNZeroFilter() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
@ -161,8 +154,9 @@ class AddNZeroFilter : public AnfVisitor {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
auto fg = node->func_graph();
auto addn = fg->NewCNode({NewValueNode(PrimAddN_)});
auto make_tuple = fg->NewCNode(filtered_Xs_);
return fg->NewCNode({addn, make_tuple});
}
@ -193,7 +187,6 @@ class AddNZeroFilter : public AnfVisitor {
}
private:
ValuePtr PrimAddN_;
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false};
};

@ -875,7 +875,6 @@ def test_merge_addn(tag):
""" test_merge_addn """
fns = FnDict()
addn = P.AddN()
AddN = P.AddN
@fns
def before(x, y, z, a):
@ -883,7 +882,7 @@ def test_merge_addn(tag):
@fns
def after(x, y, z, a):
return AddN()((a, x, y, z))
return addn((a, x, y, z))
return fns[tag]
@ -892,7 +891,6 @@ def test_addn_zero(tag):
""" test_addn_zero """
fns = FnDict()
addn = P.AddN()
AddN = P.AddN
zero_tensor = Primitive('ZerosLike')
@fns
@ -901,7 +899,7 @@ def test_addn_zero(tag):
@fns
def after(x, y, z, a):
return AddN()((a, z))
return addn((a, z))
@fns
def before_2(x, y, z, a):

Loading…
Cancel
Save