!14087 fix assign execution order problem

From: @Margaret_wangrui
Reviewed-by: @zh_qh,@hwhewei
Signed-off-by: @zh_qh
pull/14087/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 34586542bc

@ -43,19 +43,40 @@ FuncGraphManagerPtr GetManager(const AnfNodePtr &node) {
return fg->manager(); return fg->manager();
} }
// Return true if the node is only used by the given update_state node. // Return true if the node(be_used_node) is only used by the given node.
bool OnlyUpdateStateUse(const CNodePtr &update_state_node, const AnfNodePtr &node) { bool OnlyUsedByOneNode(const AnfNodePtr &be_used_node, const CNodePtr &given_node) {
auto mgr = GetManager(update_state_node); auto mgr = GetManager(given_node);
if (mgr == nullptr) { if (mgr == nullptr) {
return false; return false;
} }
auto &node_users = mgr->node_users(); auto &node_users = mgr->node_users();
auto iter = node_users.find(node); auto iter = node_users.find(be_used_node);
if (iter == node_users.end()) { if (iter == node_users.end()) {
return false; return false;
} }
auto &partial_users = iter->second; auto &partial_users = iter->second;
return (partial_users.size() == 1) && (partial_users.front().first == update_state_node); return (partial_users.size() == 1) && (partial_users.front().first == given_node);
}
// Return true if the node(be_used_node) is only used by the given two nodes(first_node and second_node).
bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_node, const AnfNodePtr &second_node) {
auto mgr = GetManager(be_used_node);
if (mgr == nullptr || first_node == second_node) {
return false;
}
auto &node_users = mgr->node_users();
auto iter = node_users.find(be_used_node);
if (iter == node_users.end()) {
return false;
}
auto &partial_users = iter->second;
if (partial_users.size() != 2) {
return false;
}
const auto &first_user = partial_users.front().first;
const auto &second_user = partial_users.back().first;
return (first_user == first_node && second_user == second_node) ||
(first_user == second_node && second_user == first_node);
} }
// Eliminate useless node that only used by associated update_state. // Eliminate useless node that only used by associated update_state.
@ -66,7 +87,7 @@ bool OnlyUpdateStateUse(const CNodePtr &update_state_node, const AnfNodePtr &nod
// To: // To:
// user(u) // user(u)
AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const AnfNodePtr &node) { AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const AnfNodePtr &node) {
if (!OnlyUpdateStateUse(update_state, node)) { if (!OnlyUsedByOneNode(node, update_state)) {
// Skip if UpdateState is not the only user of cnode. // Skip if UpdateState is not the only user of cnode.
return nullptr; return nullptr;
} }
@ -245,7 +266,7 @@ void GetLoadsFollowLoad(const CNodePtr &update_state, const CNodePtr &load, std:
void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states, void GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *update_states,
std::vector<CNodePtr> *loads) { std::vector<CNodePtr> *loads) {
if (!OnlyUpdateStateUse(update_state, make_tuple)) { if (!OnlyUsedByOneNode(make_tuple, update_state)) {
// UpdateState should be the only user of make_tuple. // UpdateState should be the only user of make_tuple.
return; return;
} }
@ -420,7 +441,8 @@ AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, cons
auto para2 = a2_cnode->input(kInputIndex); auto para2 = a2_cnode->input(kInputIndex);
auto value2 = a2_cnode->input(kAttachIndex); auto value2 = a2_cnode->input(kAttachIndex);
auto u2 = a2_cnode->input(kAssignMonadInputIndex); auto u2 = a2_cnode->input(kAssignMonadInputIndex);
if (IsPrimitiveCNode(u2, prim::kPrimUpdateState)) { // u2 is UpdateState, u2 only be used by a2
if (IsPrimitiveCNode(u2, prim::kPrimUpdateState) && OnlyUsedByOneNode(u2, a2_cnode)) {
auto a1 = u2->cast<CNodePtr>()->input(kAttachIndex); auto a1 = u2->cast<CNodePtr>()->input(kAttachIndex);
if (IsPrimitiveCNode(a1, prim::kPrimAssign)) { if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
auto a1_cnode = a1->cast<CNodePtr>(); auto a1_cnode = a1->cast<CNodePtr>();
@ -470,9 +492,10 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta
auto para3 = a3_cnode->input(kInputIndex); auto para3 = a3_cnode->input(kInputIndex);
auto value3 = a3_cnode->input(kAttachIndex); auto value3 = a3_cnode->input(kAttachIndex);
auto u3 = a3_cnode->input(kAssignMonadInputIndex); auto u3 = a3_cnode->input(kAssignMonadInputIndex);
if (IsPrimitiveCNode(u3, prim::kPrimUpdateState)) { if (IsPrimitiveCNode(u3, prim::kPrimUpdateState) && OnlyUsedByOneNode(u3, a3_cnode)) {
auto make_tuple = u3->cast<CNodePtr>()->input(kAttachIndex); auto u3_cnode = u3->cast<CNodePtr>();
if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) { auto make_tuple = u3_cnode->input(kAttachIndex);
if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple) && OnlyUsedByOneNode(make_tuple, u3_cnode)) {
auto make_tuple_cnode = make_tuple->cast<CNodePtr>(); auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
if (make_tuple_cnode->size() != kMakeTupleSize) { if (make_tuple_cnode->size() != kMakeTupleSize) {
return nullptr; return nullptr;
@ -531,7 +554,7 @@ AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_sta
// u4 = UpdateState(u1, t) // u4 = UpdateState(u1, t)
AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &make_tuple) { AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &make_tuple) {
auto make_tuple_cnode = make_tuple->cast<CNodePtr>(); auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
if (make_tuple_cnode->size() != kMakeTupleSize) { if (make_tuple_cnode->size() != kMakeTupleSize || !OnlyUsedByOneNode(make_tuple, update_state)) {
return nullptr; return nullptr;
} }
auto a2 = make_tuple_cnode->input(kInputIndex); auto a2 = make_tuple_cnode->input(kInputIndex);
@ -545,7 +568,7 @@ AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_sta
auto para2 = a2_cnode->input(kInputIndex); auto para2 = a2_cnode->input(kInputIndex);
auto value2 = a2_cnode->input(kAttachIndex); auto value2 = a2_cnode->input(kAttachIndex);
auto u2 = a2_cnode->input(kAssignMonadInputIndex); auto u2 = a2_cnode->input(kAssignMonadInputIndex);
if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState)) { if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState) || !OnlyUsedByTwoNode(u2, a2, a3)) {
return nullptr; return nullptr;
} }
auto para3 = a3_cnode->input(kInputIndex); auto para3 = a3_cnode->input(kInputIndex);

Loading…
Cancel
Save