diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc index 3c1d26d508..400176d59c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc @@ -43,19 +43,40 @@ FuncGraphManagerPtr GetManager(const AnfNodePtr &node) { return fg->manager(); } -// Return true if the node is only used by the given update_state node. -bool OnlyUpdateStateUse(const CNodePtr &update_state_node, const AnfNodePtr &node) { - auto mgr = GetManager(update_state_node); +// Return true if the node(be_used_node) is only used by the given node. +bool OnlyUsedByOneNode(const AnfNodePtr &be_used_node, const CNodePtr &given_node) { + auto mgr = GetManager(given_node); if (mgr == nullptr) { return false; } auto &node_users = mgr->node_users(); - auto iter = node_users.find(node); + auto iter = node_users.find(be_used_node); if (iter == node_users.end()) { return false; } auto &partial_users = iter->second; - return (partial_users.size() == 1) && (partial_users.front().first == update_state_node); + return (partial_users.size() == 1) && (partial_users.front().first == given_node); +} + +// Return true if the node(be_used_node) is only used by the given two nodes(first_node and second_node). +bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_node, const AnfNodePtr &second_node) { + auto mgr = GetManager(be_used_node); + if (mgr == nullptr || first_node == second_node) { + return false; + } + auto &node_users = mgr->node_users(); + auto iter = node_users.find(be_used_node); + if (iter == node_users.end()) { + return false; + } + auto &partial_users = iter->second; + if (partial_users.size() != 2) { + return false; + } + const auto &first_user = partial_users.front().first; + const auto &second_user = partial_users.back().first; + return (first_user == first_node && second_user == second_node) || + (first_user == second_node && second_user == first_node); } // Eliminate useless node that only used by associated update_state. @@ -66,7 +87,7 @@ bool OnlyUpdateStateUse(const CNodePtr &update_state_node, const AnfNodePtr &nod // To: // user(u) AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const AnfNodePtr &node) { - if (!OnlyUpdateStateUse(update_state, node)) { + if (!OnlyUsedByOneNode(node, update_state)) { // Skip if UpdateState is not the only user of cnode. return nullptr; } @@ -245,7 +266,7 @@ void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std: void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector *update_states, std::vector *loads) { - if (!OnlyUpdateStateUse(update_state, make_tuple)) { + if (!OnlyUsedByOneNode(make_tuple, update_state)) { // UpdateState should be the only user of make_tuple. return; } @@ -420,7 +441,8 @@ AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, cons auto para2 = a2_cnode->input(kInputIndex); auto value2 = a2_cnode->input(kAttachIndex); auto u2 = a2_cnode->input(kAssignMonadInputIndex); - if (IsPrimitiveCNode(u2, prim::kPrimUpdateState)) { + // u2 is UpdateState, u2 only be used by a2 + if (IsPrimitiveCNode(u2, prim::kPrimUpdateState) && OnlyUsedByOneNode(u2, a2_cnode)) { auto a1 = u2->cast()->input(kAttachIndex); if (IsPrimitiveCNode(a1, prim::kPrimAssign)) { auto a1_cnode = a1->cast(); @@ -470,9 +492,10 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta auto para3 = a3_cnode->input(kInputIndex); auto value3 = a3_cnode->input(kAttachIndex); auto u3 = a3_cnode->input(kAssignMonadInputIndex); - if (IsPrimitiveCNode(u3, prim::kPrimUpdateState)) { - auto make_tuple = u3->cast()->input(kAttachIndex); - if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) { + if (IsPrimitiveCNode(u3, prim::kPrimUpdateState) && OnlyUsedByOneNode(u3, a3_cnode)) { + auto u3_cnode = u3->cast(); + auto make_tuple = u3_cnode->input(kAttachIndex); + if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple) && OnlyUsedByOneNode(make_tuple, u3_cnode)) { auto make_tuple_cnode = make_tuple->cast(); if (make_tuple_cnode->size() != kMakeTupleSize) { return nullptr; @@ -531,7 +554,7 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta // u4 = UpdateState(u1, t) AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &make_tuple) { auto make_tuple_cnode = make_tuple->cast(); - if (make_tuple_cnode->size() != kMakeTupleSize) { + if (make_tuple_cnode->size() != kMakeTupleSize || !OnlyUsedByOneNode(make_tuple, update_state)) { return nullptr; } auto a2 = make_tuple_cnode->input(kInputIndex); @@ -545,7 +568,7 @@ AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_sta auto para2 = a2_cnode->input(kInputIndex); auto value2 = a2_cnode->input(kAttachIndex); auto u2 = a2_cnode->input(kAssignMonadInputIndex); - if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState)) { + if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState) || !OnlyUsedByTwoNode(u2, a2, a3)) { return nullptr; } auto para3 = a3_cnode->input(kInputIndex);