From 9aa6d203f14fa87fb6f18ab4fde8942883c4eb84 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Mon, 31 Aug 2020 17:24:58 +0800 Subject: [PATCH] Fix output device address setting for inputs of depend node --- .../ccsrc/backend/session/cpu_session.cc | 4 +- mindspore/ccsrc/backend/session/cpu_session.h | 2 +- .../ccsrc/backend/session/session_basic.cc | 38 ++++++------------- .../ccsrc/backend/session/session_basic.h | 8 ++-- 4 files changed, 18 insertions(+), 34 deletions(-) diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index 5397d15808..db5663b31b 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -35,7 +35,7 @@ namespace mindspore { namespace session { -ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { +ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(graph); if (!anf->isa()) { @@ -49,7 +49,7 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, ParameterPtr new_parameter = graph->NewParameter(anf->cast()); TraceManager::EndTrace(); graph_inputs->push_back(new_parameter); - valid_inputs->push_back(valid_input); + valid_inputs->push_back(true); return new_parameter; } diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index 08e09d929e..3533853f88 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -37,7 +37,7 @@ class CPUSession : public SessionBasic { std::map *tensor_to_node) override; protected: - ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; + ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override; void Optimize(const std::shared_ptr &kernel_graph); private: diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 945f037ec8..bfa7f367ca 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -395,8 +395,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const } } -std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, - KernelGraph *graph) { +std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); std::vector parameters; @@ -418,7 +417,7 @@ std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr parameter->set_abstract(abstract); auto new_parameter = graph->NewParameter(parameter); parameters.push_back(new_parameter); - valid_inputs->push_back(valid_input); + valid_inputs->push_back(true); graph_inputs->push_back(new_parameter); }; for (const auto &out_node : pre_graph_out) { @@ -442,8 +441,7 @@ std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr return parameters; } -ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, - KernelGraph *graph) { +ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); if (!anf->isa()) { MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; @@ -471,15 +469,15 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf TraceManager::EndTrace(); } graph_inputs->push_back(new_parameter); - valid_inputs->push_back(valid_input); + valid_inputs->push_back(true); return new_parameter; } -AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { +AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; - auto parameters = CreateParameterFromTuple(anf, valid_input, graph); + auto parameters = CreateParameterFromTuple(anf, graph); if (parameters.empty()) { MS_LOG(INFO) << "Empty parameter from cnode"; return nullptr; @@ -495,14 +493,11 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool return make_tuple; } -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, - bool *from_other_graph, +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, std::unordered_map *other_graph_cnode) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(from_other_graph); MS_EXCEPTION_IF_NULL(other_graph_cnode); - *from_other_graph = false; // get primitive of old node std::vector cnode_inputs; auto prim = AnfAlgo::GetCNodePrimitive(cnode); @@ -544,7 +539,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K } continue; } else if (anf->isa()) { - auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); + auto new_parameter = CreateNewParameterFromParameter(anf, graph); cnode_inputs.push_back(new_parameter); if (GetGraphIdByNode(anf) == kInvalidGraphId) { graph->FrontBackendlMapAdd(anf, new_parameter); @@ -558,9 +553,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K } else if (optimize_control_depend) { cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); } else { - *from_other_graph = true; // the input node is a cnode from other graph - auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); + auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph); if (parameter_from_cnode == nullptr) { parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); } @@ -587,7 +581,7 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra } else { KernelGraphPtr kernel_graph = NewKernelGraph(); MS_EXCEPTION_IF_NULL(kernel_graph); - auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); + auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get()); auto primitive = NewValueNode(std::make_shared(prim::kPrimReturn->name())); auto return_node = kernel_graph->NewCNode({primitive, parameter}); kernel_graph->set_return(return_node); @@ -806,7 +800,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con auto graph = NewKernelGraph(); MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Create graph: " << graph->graph_id(); - size_t from_other_graph_depend_num = 0; for (const auto &node : lst) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); @@ -816,16 +809,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); // create a new cnode object - bool from_other_graph = false; - // only first depend from other graph can create - bool valid_input = true; - if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { - valid_input = false; - } - auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { - from_other_graph_depend_num++; - } + auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode); MS_EXCEPTION_IF_NULL(new_cnode); new_cnode->set_abstract(cnode->abstract()); new_cnode->set_scope(cnode->scope()); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index c0e0ed8d22..d8483f13b7 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -100,7 +100,7 @@ class SessionBasic : public std::enable_shared_from_this { std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph, std::vector *all_out_graph); - CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, + CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, std::unordered_map *other_graph_cnode); CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); @@ -153,11 +153,11 @@ class SessionBasic : public std::enable_shared_from_this { const std::vector &tensors_mask); // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); - std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); - virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); + std::vector CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph); + virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph); ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); - AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); + AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph); void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector &node_list);