From 0590ade91d00671ec327955d0d2985ac50496262 Mon Sep 17 00:00:00 2001 From: huangbingjian Date: Sat, 13 Mar 2021 10:22:57 +0800 Subject: [PATCH] modify SetitemEliminator --- .../irpass/item_tuple_or_list_eliminate.h | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h index 5a0dafbff2..4a1b55beb0 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h @@ -191,14 +191,18 @@ class GetitemConstEliminator : public AnfVisitor { // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) // setitem((a, b, c, ...), 1, z) => (a, z, c, ...) -// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} -// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z} +// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, a, b, c, ...}, 0, z} => {prim::kPrimMakeTuple, z, b, c, ...} +// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...} +// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...} +// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...} class SetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimTupleSetItem, {IsVNode, IsVNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimListSetItem, {IsVNode, IsVNode, IsNode})(node); auto fg = node->func_graph(); if (fg != nullptr && z_ != nullptr) { @@ -225,7 +229,27 @@ class SetitemEliminator : public AnfVisitor { } void Visit(const ValueNodePtr &vnode) override { - if (!args_.empty() && IsValueNode(vnode)) { + if (args_.empty() && IsValueNode(vnode)) { + auto tuple = GetValueNode(vnode); + if (tuple != nullptr) { + args_.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (auto &val : tuple->value()) { + auto val_node = std::make_shared(val); + val_node->set_abstract(val->ToAbstract()); + args_.emplace_back(val_node); + } + } + } else if (args_.empty() && IsValueNode(vnode)) { + auto list = GetValueNode(vnode); + if (list != nullptr) { + args_.emplace_back(NewValueNode(prim::kPrimMakeList)); + for (auto &val : list->value()) { + auto val_node = std::make_shared(val); + val_node->set_abstract(val->ToAbstract()); + args_.emplace_back(val_node); + } + } + } else if (!args_.empty() && IsValueNode(vnode)) { auto idx = GetValue(vnode->value()); if (idx < 0) { idx = idx + args_.size() - 1;