From 986408babf6fef767962e5f017225ca8de4ac0c4 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Tue, 4 Aug 2020 16:25:27 +0800 Subject: [PATCH] ignore create parameter from control depend inputs --- .../ccsrc/backend/session/session_basic.cc | 32 ++++++++++-- mindspore/ccsrc/vm/segment_runner.cc | 52 ++++++++++++++----- 2 files changed, 67 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 21ff3180e3..c95faf90bc 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) { } return false; } + +bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &node_inputs = cnode->inputs(); + for (size_t i = 1; i < node_inputs.size(); ++i) { + if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) { + return false; + } + } + return true; +} } // namespace GraphId SessionBasic::graph_sum_ = 0; @@ -354,8 +370,11 @@ std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr MS_EXCEPTION_IF_NULL(graph); std::vector parameters; std::vector pre_graph_out = {node}; + if (IgnoreCreateParameterForMakeTuple(node)) { + pre_graph_out.clear(); + } // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive - if (!AnfAlgo::IsRealKernel(node)) { + if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) { pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); } auto valid_inputs = graph->MutableValidInputs(); @@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; auto parameters = CreateParameterFromTuple(anf, valid_input, graph); if (parameters.empty()) { - MS_LOG(EXCEPTION) << "No parameter exist!!"; + MS_LOG(INFO) << "Empty parameter from cnode"; + return nullptr; } if (parameters.size() == 1) { return parameters[0]; @@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); continue; } else if (optimize_control_depend) { - cnode_inputs.push_back(NewValueNode(MakeValue(input_idx))); + cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); } else { *from_other_graph = true; // the input node is a cnode from other graph auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); + if (parameter_from_cnode == nullptr) { + parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); + } cnode_inputs.push_back(parameter_from_cnode); (*other_graph_cnode)[anf] = parameter_from_cnode; } @@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap auto tensor = inputs[i]; MS_EXCEPTION_IF_NULL(tensor); auto input_node = input_nodes[i]; - if (TensorNeedSync(input_node, tensor) && input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(input_node->cast())) { diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 151c20a535..141adc1bff 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -79,6 +79,42 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, c return output; } +namespace { +AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr, + AnfNodePtrToAnfNodePtrMap *eqv_ptr) { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(inputs_ptr); + MS_EXCEPTION_IF_NULL(eqv_ptr); + MS_EXCEPTION_IF_NULL(node); + auto &inputs = *inputs_ptr; + auto &eqv = *eqv_ptr; + if (node->isa() && !IsValueNode(node)) { + eqv[node] = node; + } else if (eqv.find(node) == eqv.end()) { + bool ignore_make_tuple = false; + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + ignore_make_tuple = true; + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &node_inputs = cnode->inputs(); + for (size_t i = 1; i < node_inputs.size(); ++i) { + if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) { + ignore_make_tuple = false; + break; + } + } + } + if (!ignore_make_tuple) { + inputs.push_back(node); + } + eqv[node] = fg->add_parameter(); + eqv[node]->set_abstract(node->abstract()); + eqv[node]->set_kernel_info(node->kernel_info_ptr()); + } + return eqv[node]; +} +} // namespace + std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { auto fg = std::make_shared(); AnfNodePtrList inputs; @@ -86,17 +122,6 @@ std::tuple TransformSegmentToAnfGr if (lst.empty()) { MS_LOG(EXCEPTION) << "Input anf node list is empty"; } - auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { - if (a->isa() && !IsValueNode(a)) { - eqv[a] = a; - } else if (eqv.find(a) == eqv.end()) { - inputs.push_back(a); - eqv[a] = fg->add_parameter(); - eqv[a]->set_abstract(a->abstract()); - eqv[a]->set_kernel_info(a->kernel_info_ptr()); - } - return eqv[a]; - }; // Merge CNodes into a AnfGraph that represents a linear instruction segment for (auto n : lst) { if (!n->isa()) { @@ -122,11 +147,12 @@ std::tuple TransformSegmentToAnfGr if (inps[i]->isa() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { args.emplace_back(NewValueNode(MakeValue(i))); } else { - args.emplace_back(ref(inps[i])); + args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv)); } } } else { - (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); + (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), + [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); } eqv[n] = fg->NewCNode(args); eqv[n]->set_abstract(n->abstract());