From 0d4866f8da1e3a63ede6d324a75d85e3a98d6999 Mon Sep 17 00:00:00 2001 From: kswang Date: Mon, 12 Oct 2020 09:59:00 +0800 Subject: [PATCH] optimize graph cut for depend --- mindspore/ccsrc/backend/session/session_basic.cc | 9 ++++----- mindspore/ccsrc/vm/segment_runner.cc | 7 +++---- mindspore/core/ir/anf.cc | 10 +++++++--- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ffd5b48dda..84c2eb953b 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -515,8 +515,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, MS_EXCEPTION_IF_NULL(other_graph_cnode); MS_EXCEPTION_IF_NULL(cnode_inputs); auto origin_inputs = cnode->inputs(); - bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && - origin_inputs[kRealInputIndexInDepend]->isa(); + bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3; bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { @@ -526,6 +525,9 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; + } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { + cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); + continue; } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { cnode_inputs->push_back((*other_graph_cnode)[anf]); continue; @@ -545,9 +547,6 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, (*other_graph_cnode)[anf] = new_parameter; } continue; - } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { - cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]); - continue; } else if (optimize_control_depend) { cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); } else { diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index f598db2f2d..c27e3e5673 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -141,10 +141,9 @@ std::tuple TransformSegmentToAnfGr } auto fn = inps[0]; std::vector args{fn}; - if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa() && - eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { - args.emplace_back(inps[kRealInputIndexInDepend]); - args.emplace_back(inps[kRealInputIndexInDepend]); + if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { + args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv)); + args.emplace_back(NewValueNode(MakeValue(0))); } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { for (size_t i = 1; i < inps.size(); ++i) { if (inps[i]->isa() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 0446ab988d..52874d561f 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -289,10 +289,14 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { } return target; } - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + auto &inputs = cnode->inputs(); + if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) { + return GetCNodeTarget(inputs[1]); + } + } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { return GetMaketupleNodeTarget(cnode); - } - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { return GetTupleGetItemTarget(cnode, primitive); } return default_target;