From 9b26c210f4d92af3440a08ceeefbf362339f178f Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Thu, 25 Feb 2021 09:45:21 +0800 Subject: [PATCH] Eliminate all useless nodes related to UpdateStates. --- .../ccsrc/backend/session/session_basic.cc | 5 +- .../optimizer/irpass/incorporate_getitem.h | 4 +- .../optimizer/irpass/updatestate_eliminate.cc | 51 +++++++++++++++++-- .../pipeline_transformer.cc | 11 ++-- .../device/ascend/tasksink/task_generator.cc | 2 +- tests/ut/cpp/ir/manager_test.cc | 2 +- 6 files changed, 58 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 997aa03d62..a839eaddeb 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1875,10 +1875,9 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) { std::vector ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager, const AnfNodePtr &front_node) { - auto node_users = front_func_graph_manager->node_users(); - auto users = node_users[front_node]; + auto &users = front_func_graph_manager->node_users()[front_node]; std::vector result; - for (auto user : users) { + for (auto &user : users) { if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) { continue; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 0e898c8443..06ccb1e612 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -433,12 +433,12 @@ class IncorporateGetitemSwitch : public AnfVisitor { MS_EXCEPTION_IF_NULL(switch_call_cnode); auto manager = fg->manager(); MS_EXCEPTION_IF_NULL(manager); - auto node_users_map = manager->node_users(); + auto &node_users_map = manager->node_users(); auto it = node_users_map.find(switch_call); if (it == node_users_map.end()) { return false; } - auto node_users = it->second; + auto &node_users = it->second; // If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair &user) { return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc index e16143b968..5eb97667cc 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc @@ -291,6 +291,50 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd return make_tuple; } +// Remove all nodes related to UpdateStates, if they're redundant. +void EliminateUselessNodesForUpdateStates(const std::vector &update_states) { + if (update_states.empty()) { + return; + } + auto mgr = GetManager(update_states[0]); + + // 1. Remove the use of UpdateState nodes, except the last one. + for (auto i = update_states.size() - 1; i > 0; i--) { + auto &us = update_states[i]; + mgr->Replace(us, us->input(kInputIndex)); + } + + // 2. Remove the Depend users of last UpdateState node. + auto &node_users = mgr->node_users(); + auto iter = node_users.find(update_states[0]); + if (iter == node_users.end()) { + return; + } + auto &us_users = iter->second; + if (us_users.size() < 2) { + return; + } + std::vector depend_nodes; + for (auto &user : us_users) { + if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) { + depend_nodes.emplace_back(user.first); + } + } + if (depend_nodes.empty()) { + return; + } + ssize_t end = 0; + // If all users are Depend CNode. + if (depend_nodes.size() == us_users.size()) { + end = 1; + } + for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) { + const auto &depend_node = depend_nodes[i]; + const auto &depend_cnode = depend_node->cast(); + mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex)); + } +} + // Eliminate UpdateStates for consecutive Loads. // Convert: // x1 = Load(x1, u) @@ -336,10 +380,9 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const mgr->SetEdge(load, kAttachIndex, input_monad); } } - for (auto i = update_states.size() - 1; i > 0; i--) { - auto &us = update_states[i]; - mgr->Replace(us, us->input(kInputIndex)); - } + + EliminateUselessNodesForUpdateStates(update_states); + if (make_tuple_inputs.size() == 1) { // This should not happen. MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2); diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index 6c4c4f58bd..40e0ba43f3 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -52,25 +52,24 @@ static bool IsInWhiteList(const CNodePtr &cnode) { return false; } -static void SetGradTag(const AnfNodePtr &node, NodeUsersMap node_users_map) { - auto node_users = node_users_map[node]; +static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { + const auto &node_users = manager->node_users()[node]; for (auto &user_pair : node_users) { auto user_node = user_pair.first; if (!user_node->grad()) { user_node->set_grad(true); - SetGradTag(user_node, node_users_map); + SetGradTag(user_node, manager); } } } void PipelineTransformer::LabelRequiredGradCNode() { auto parameters = root_->parameters(); - auto node_users_map = manager_->node_users(); for (auto parameter : parameters) { if (!ParameterRequireGrad(parameter)) { continue; } - SetGradTag(parameter, node_users_map); + SetGradTag(parameter, manager_); } } @@ -243,7 +242,7 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) { while (need_coloring) { need_coloring = false; auto all_nodes = func->nodes(); - auto node_users = manager_->node_users(); + auto &node_users = manager_->node_users(); for (auto &node : all_nodes) { if (node->isa() || node->stage() == -1) { continue; diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc index c01a77b547..8688208ff0 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -58,7 +58,7 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre MS_EXCEPTION_IF_NULL(graph); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - auto node_users = manager->node_users(); + auto &node_users = manager->node_users(); if (node_users[anf_node_ptr].empty()) { MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; } diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 8b8b0d9151..69583185ee 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -391,7 +391,7 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(2, f->nodes().size()); ASSERT_EQ(1, g->nodes().size()); - auto users = mng->node_users(); + auto &users = mng->node_users(); for (auto& iter : users) { ASSERT_EQ(1, iter.second.size()); }