|
|
|
@ -35,9 +35,6 @@ namespace irpass {
|
|
|
|
|
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
|
|
|
|
|
class MergeAddN : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
|
|
|
|
|
~MergeAddN() override = default;
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
|
|
|
|
Reset();
|
|
|
|
|
optimizer_ = optimizer;
|
|
|
|
@ -47,15 +44,15 @@ class MergeAddN : public AnfVisitor {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
// {PrimAddNClass}
|
|
|
|
|
auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)});
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
|
|
|
|
|
(void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
auto make_node = fg->NewCNode(args_);
|
|
|
|
|
|
|
|
|
|
return fg->NewCNode({addn_node, make_node});
|
|
|
|
|
return fg->NewCNode({addn, make_node});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override {
|
|
|
|
@ -127,7 +124,6 @@ class MergeAddN : public AnfVisitor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
ValuePtr PrimAddN_;
|
|
|
|
|
OptimizerPtr optimizer_{nullptr};
|
|
|
|
|
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
|
|
|
|
|
bool is_inner_{false}, is_outer_{false}, is_match_{false};
|
|
|
|
@ -136,9 +132,6 @@ class MergeAddN : public AnfVisitor {
|
|
|
|
|
// {PrimAddN, {kPrimMakeTuple, Xs}}
|
|
|
|
|
class AddNZeroFilter : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
|
|
|
|
|
~AddNZeroFilter() override = default;
|
|
|
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
Reset();
|
|
|
|
|
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
|
|
|
|
@ -161,8 +154,9 @@ class AddNZeroFilter : public AnfVisitor {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
auto addn = fg->NewCNode({NewValueNode(PrimAddN_)});
|
|
|
|
|
auto make_tuple = fg->NewCNode(filtered_Xs_);
|
|
|
|
|
return fg->NewCNode({addn, make_tuple});
|
|
|
|
|
}
|
|
|
|
@ -193,7 +187,6 @@ class AddNZeroFilter : public AnfVisitor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
ValuePtr PrimAddN_;
|
|
|
|
|
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
|
|
|
|
|
bool has_zero_like_{false};
|
|
|
|
|
};
|
|
|
|
|