|
|
|
@ -928,9 +928,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) {
|
|
|
|
|
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_input);
|
|
|
|
|
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node.";
|
|
|
|
|
}
|
|
|
|
@ -940,24 +939,37 @@ void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const Anf
|
|
|
|
|
auto ret = partial_kernel_graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ret);
|
|
|
|
|
auto return_input = ret->input(kFirstDataInputIndex);
|
|
|
|
|
// if kernel graph return node is a function
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
|
|
|
|
|
// return node is a function
|
|
|
|
|
std::vector<AnfNodePtr> call_inputs = {
|
|
|
|
|
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
|
|
|
|
AnfNodePtr real_kernel_graph;
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
|
|
|
|
|
auto return_input_cnode = return_input->cast<CNodePtr>();
|
|
|
|
|
|
|
|
|
|
auto partial_inputs = return_input_cnode->inputs();
|
|
|
|
|
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
|
|
|
|
|
real_kernel_graph = partial_inputs[kFirstDataInputIndex];
|
|
|
|
|
} else { // return node is kernel graph
|
|
|
|
|
call_inputs.emplace_back(return_input);
|
|
|
|
|
real_kernel_graph = return_input;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// new call node inputs
|
|
|
|
|
for (auto real_input : real_inputs) {
|
|
|
|
|
auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
|
|
|
|
|
call_inputs.emplace_back(parameter_for_input);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto call_node = partial_kernel_graph->NewCNode(call_inputs);
|
|
|
|
|
// update abstract
|
|
|
|
|
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_kernel_graph);
|
|
|
|
|
if (real_kernel_graph->isa<ValueNode>() && IsValueNode<FuncGraph>(real_kernel_graph)) {
|
|
|
|
|
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(real_kernel_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph);
|
|
|
|
|
auto ret_partial = sub_partial_kernel_graph->get_return();
|
|
|
|
|
call_node->set_abstract(ret_partial->abstract());
|
|
|
|
|
}
|
|
|
|
|
// update return input
|
|
|
|
|
ret->set_input(kFirstDataInputIndex, call_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
@ -977,9 +989,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|
|
|
|
auto node = make_tuple_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto make_tuple_inputs = node->inputs();
|
|
|
|
|
// there is real input in call, should put it to make_tuple in switch_layer
|
|
|
|
|
auto real_input = cnode->input(kFirstDataInputIndex);
|
|
|
|
|
auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input);
|
|
|
|
|
// there are real inputs in call, should put it to make_tuple in switch_layer
|
|
|
|
|
std::vector<AnfNodePtr> real_inputs;
|
|
|
|
|
for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
|
|
|
|
|
real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> new_make_tuple_inputs = {
|
|
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
|
|
|
|
|
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
|
|
|
|
@ -990,10 +1004,18 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|
|
|
|
auto partial_node = partial_idx->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node);
|
|
|
|
|
// update kernel graph when switch_layer node return function
|
|
|
|
|
CreateCallNodeReturnFunction(partial_node, real_input_back);
|
|
|
|
|
auto partial_input = partial_node->input(kFirstDataInputIndex);
|
|
|
|
|
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
|
|
|
|
|
auto ret = partial_kernel_graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ret);
|
|
|
|
|
auto return_input = ret->input(kFirstDataInputIndex);
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode<KernelGraph>(return_input)) {
|
|
|
|
|
CreateCallNodeReturnFunction(partial_node, real_inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
|
|
|
|
|
new_partial_inputs.emplace_back(real_input_back);
|
|
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
|
|
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs);
|
|
|
|
|
new_make_tuple_inputs.emplace_back(new_partial);
|
|
|
|
|
}
|
|
|
|
@ -1003,7 +1025,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|
|
|
|
std::vector<AnfNodePtr> new_partial_inputs;
|
|
|
|
|
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
|
|
|
|
|
new_partial_inputs.emplace_back(partial_idx);
|
|
|
|
|
new_partial_inputs.emplace_back(real_input_back);
|
|
|
|
|
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
|
|
|
|
|
auto new_partial = graph->NewCNode(new_partial_inputs);
|
|
|
|
|
new_make_tuple_inputs.emplace_back(new_partial);
|
|
|
|
|
}
|
|
|
|
|