|
|
|
@ -447,6 +447,37 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|
|
|
|
return new_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
// create primitive of cnode:call(partial or switch)
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs = {
|
|
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
|
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_input);
|
|
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
|
|
|
|
if (cnode_input == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
|
|
|
|
|
<< ", but input[0] has not been created.";
|
|
|
|
|
}
|
|
|
|
|
// if the node is partial, insert the inputs of partial to the call
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
|
|
|
|
|
auto partial_node = attr_input->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node);
|
|
|
|
|
auto partial_inputs = partial_node->inputs();
|
|
|
|
|
std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
|
|
|
|
|
std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
|
|
|
|
|
return graph->GetBackendAnfByFrontAnf(node);
|
|
|
|
|
});
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
|
|
|
|
|
cnode_inputs.emplace_back(cnode_input);
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
@ -471,18 +502,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (attr_input->isa<CNode>()) {
|
|
|
|
|
// create primitive of cnode:call(switch)
|
|
|
|
|
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
|
|
|
|
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
|
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
|
|
|
|
if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
|
|
|
|
|
}
|
|
|
|
|
cnode_inputs.emplace_back(cnode_input);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
|
|
|
|
|
<< ", but input[0] has not been created.";
|
|
|
|
|
}
|
|
|
|
|
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
|
|
|
|
|
} else {
|
|
|
|
|
// get primitive of old node
|
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
|