From 2ed4ad0f2aac1b29cae2e1ad35b119da39896ddc Mon Sep 17 00:00:00 2001 From: huanghui Date: Wed, 22 Apr 2020 09:52:45 +0800 Subject: [PATCH] optimize_dependece pass enhance --- .../pre_activate/pass/optimize_dependence.cc | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc index db32354abf..86a90a4dfe 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc @@ -28,8 +28,7 @@ namespace mindspore { namespace opt { constexpr auto kSingleInputIndex = 1; namespace { -AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); +AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { return nullptr; @@ -41,15 +40,6 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { return nullptr; } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - // Check whether the node has only one output node. - if (manager->node_users().find(cnode) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The node should be used by at least another node's input"; - } - if (manager->node_users()[cnode].size() > 1) { - return nullptr; - } CheckCNodeInputSize(cnode, kSingleInputIndex + 1); return cnode->input(kSingleInputIndex); } @@ -63,7 +53,7 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { std::vector new_make_tuple_inputs; bool need_update = false; for (const auto &input : cnode->inputs()) { - AnfNodePtr replace_input = GetReplaceNode(func_graph, input); + AnfNodePtr replace_input = GetReplaceNode(input); // If replace input is not null, it will be the input of the TransData or Cast. if (replace_input == nullptr) { new_make_tuple_inputs.push_back(input); @@ -119,7 +109,7 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con if (ReplaceMakeTuple(func_graph, replacing_cnode)) { return nullptr; } - AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); + AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); if (replace_node == nullptr) { MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); return nullptr;