@@ -148,7 +148,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
     }
   }
   tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-  // if in paynative mode,data only copyed to host when user want to print data
+  // if in pynative mode,data only copied to host when user want to print data
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
@@ -499,10 +499,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
   auto graph_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(graph_inputs);
   auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
-    auto parameter = graph->NewParameter();
-    MS_EXCEPTION_IF_NULL(parameter);
-    parameter->set_abstract(abstract);
-    auto new_parameter = graph->NewParameter(parameter);
+    auto new_parameter = graph->NewParameter(abstract);
     parameters.push_back(new_parameter);
     valid_inputs->push_back(true);
     graph_inputs->push_back(new_parameter);
@@ -662,7 +659,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
   return new_cnode;
 }
 
-CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
+CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(node_input);
   MS_EXCEPTION_IF_NULL(graph);
   // switch input generalizes partial
@@ -675,9 +672,11 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
   } else {
     KernelGraphPtr kernel_graph = NewKernelGraph();
     MS_EXCEPTION_IF_NULL(kernel_graph);
-    auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get());
+    auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
+    parameter->set_abstract(cnode->abstract());
     auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
     auto return_node = kernel_graph->NewCNode({primitive, parameter});
+    return_node->set_abstract(cnode->abstract());
     kernel_graph->set_return(return_node);
     partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
     partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
@@ -722,10 +721,97 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
   return cnode_inputs;
 }
 
+void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(real_input);
+  if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
+    MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node.";
+  }
+  auto partial_input = cnode->input(kFirstDataInputIndex);
+  KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
+  MS_EXCEPTION_IF_NULL(partial_kernel_graph);
+  auto ret = partial_kernel_graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  auto return_input = ret->input(kFirstDataInputIndex);
+  // if kernel graph return node is a function
+  if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
+    std::vector<AnfNodePtr> call_inputs = {
+      partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
+    auto return_input_cnode = return_input->cast<CNodePtr>();
+
+    auto partial_inputs = return_input_cnode->inputs();
+    call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
+    auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
+    call_inputs.emplace_back(parameter_for_input);
+    auto call_node = partial_kernel_graph->NewCNode(call_inputs);
+    // update abstract
+    KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]);
+    auto ret_partial = sub_partial_kernel_graph->get_return();
+    call_node->set_abstract(ret_partial->abstract());
+    // update return input
+    ret->set_input(kFirstDataInputIndex, call_node);
+  }
+}
+
+std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> cnode_inputs = {
+    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
+  auto attr_input = cnode->input(kAnfPrimitiveIndex);
+  MS_EXCEPTION_IF_NULL(attr_input);
+  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
+  auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(switch_layer_cnode);
+  std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
+                                                 switch_layer_cnode->input(kFirstDataInputIndex)};
+  auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
+  MS_EXCEPTION_IF_NULL(make_tuple_node);
+  auto node = make_tuple_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(node);
+  auto make_tuple_inputs = node->inputs();
+  // there is real input in call, should put it to make_tuple in switch_layer
+  auto real_input = cnode->input(kFirstDataInputIndex);
+  auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input);
+  std::vector<AnfNodePtr> new_make_tuple_inputs = {
+    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
+  for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
+    auto partial_idx = make_tuple_inputs[idx];
+    MS_EXCEPTION_IF_NULL(cnode->abstract());
+    // switch_layer node input is partial cnode
+    if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
+      auto partial_node = partial_idx->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(partial_node);
+      // update kernel graph when switch_layer node return function
+      CreateCallNodeReturnFunction(partial_node, real_input_back);
+
+      std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
+      new_partial_inputs.emplace_back(real_input_back);
+      auto new_partial = graph->NewCNode(new_partial_inputs);
+      new_make_tuple_inputs.emplace_back(new_partial);
+    }
+    // switch_layer node input is kernel graph value node
+    if (IsValueNode<KernelGraph>(partial_idx)) {
+      // make_tuple inputs is KernelGraph
+      std::vector<AnfNodePtr> new_partial_inputs;
+      new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
+      new_partial_inputs.emplace_back(partial_idx);
+      new_partial_inputs.emplace_back(real_input_back);
+      auto new_partial = graph->NewCNode(new_partial_inputs);
+      new_make_tuple_inputs.emplace_back(new_partial);
+    }
+  }
+  auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
+  switch_layer_inputs.emplace_back(new_make_tuple);
+  auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
+  cnode_inputs.emplace_back(new_switch_layer);
+  return cnode_inputs;
+}
+
 std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(graph);
-  // create primitive of cnode:call(partial or switch)
+  // create primitive of cnode:call(partial or switch or switch_layer)
   std::vector<AnfNodePtr> cnode_inputs = {
     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
   auto attr_input = cnode->input(kAnfPrimitiveIndex);
@@ -748,9 +834,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
     return cnode_inputs;
   } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
     return CreateCallSwitchInputs(cnode, graph);
+  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
+    return CreateCallSwitchLayerInputs(cnode, graph);
   }
   MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
-                << "must be partial or switch.";
+                << "must be partial or switch or switch_layer.";
   return {};
 }
 
@@ -788,7 +876,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
     cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
     for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
       auto node_input = cnode->input(index);
-      auto switch_input = CreateSwitchInput(node_input, graph);
+      auto switch_input = CreateSwitchInput(cnode, node_input, graph);
       cnode_inputs->emplace_back(switch_input);
     }
   } else {
@@ -841,10 +929,17 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
   // if the cnode is call switch, remove call
   if (new_cnode->inputs().size() > 1) {
     auto first_input = new_cnode->input(kFirstDataInputIndex);
     MS_EXCEPTION_IF_NULL(first_input);
     if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
         AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
       new_cnode = first_input->cast<CNodePtr>();
     }
+    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
+        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
+      auto abstract = cnode->abstract();
+      new_cnode = first_input->cast<CNodePtr>();
+      new_cnode->set_abstract(abstract);
+    }
   }
+
   return new_cnode;
@@ -1842,7 +1937,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   // PS embeddingLookup cache check.
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
-    MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in "
+    MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in "
                          "parameter server training mode.";
   }
   std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());