From ad8731944cf8925967dbc96f76d7fce741baa576 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Wed, 1 Jul 2020 10:31:55 +0800 Subject: [PATCH] handle partial node in CreateNewCNode --- mindspore/ccsrc/session/session_basic.cc | 44 +++++++++++++++++------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 8935d3df2f..8382d6af9b 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -447,6 +447,37 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K return new_cnode; } +static std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + // create primitive of cnode:call(partial or switch) + std::vector cnode_inputs = { + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + if (cnode_input == nullptr) { + MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() + << ", but input[0] has not been created."; + } + // if the node is partial, insert the inputs of partial to the call + if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) { + auto partial_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(partial_node); + auto partial_inputs = partial_node->inputs(); + std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(), + std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node)); + return graph->GetBackendAnfByFrontAnf(node); + }); + return cnode_inputs; + } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { + cnode_inputs.emplace_back(cnode_input); + return cnode_inputs; + } + MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; +} + CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); @@ -471,18 +502,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) } } } else if (attr_input->isa()) { - // create primitive of cnode:call(switch) - cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; - if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { - auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); - if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - MS_LOG(EXCEPTION) << "CNode input[0] must be switch."; - } - cnode_inputs.emplace_back(cnode_input); - } else { - MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() - << ", but input[0] has not been created."; - } + cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); } else { // get primitive of old node auto prim = AnfAlgo::GetCNodePrimitive(cnode);