|
|
@ -691,47 +691,39 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
|
|
|
|
AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
std::vector<AnfNodePtr> parameters;
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> pre_graph_out = {node};
|
|
|
|
|
|
|
|
if (IgnoreCreateParameterForMakeTuple(node)) {
|
|
|
|
if (IgnoreCreateParameterForMakeTuple(node)) {
|
|
|
|
pre_graph_out.clear();
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
|
|
|
|
|
|
|
|
auto parameters = AnfAlgo::GetAllOutput(new_parameter);
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> pre_graph_out = {node};
|
|
|
|
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
|
|
|
|
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
|
|
|
|
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
|
|
|
|
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
|
|
|
|
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
|
|
|
|
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for (const auto ¶meter : parameters) {
|
|
|
|
auto valid_inputs = graph->MutableValidInputs();
|
|
|
|
auto valid_inputs = graph->MutableValidInputs();
|
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs);
|
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs);
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
|
|
|
|
|
|
|
|
auto new_parameter = graph->NewParameter(abstract);
|
|
|
|
|
|
|
|
parameters.push_back(new_parameter);
|
|
|
|
|
|
|
|
valid_inputs->push_back(true);
|
|
|
|
valid_inputs->push_back(true);
|
|
|
|
graph_inputs->push_back(new_parameter);
|
|
|
|
graph_inputs->push_back(parameter);
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t param_index = 0;
|
|
|
|
for (const auto &out_node : pre_graph_out) {
|
|
|
|
for (const auto &out_node : pre_graph_out) {
|
|
|
|
MS_EXCEPTION_IF_NULL(out_node);
|
|
|
|
size_t output_size = AnfAlgo::GetOutputTensorNum(out_node);
|
|
|
|
auto abstract = out_node->abstract();
|
|
|
|
for (size_t i = 0; i < output_size; i++) {
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract);
|
|
|
|
if (param_index >= parameters.size()) {
|
|
|
|
// create multiple parameters if is a tuple output real kernel
|
|
|
|
MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << "out of range.Node:" << node->DebugString()
|
|
|
|
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
|
|
|
|
<< ",out_node:" << out_node->DebugString();
|
|
|
|
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
|
|
|
|
|
|
|
|
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
|
|
|
|
|
|
|
|
create_parameter((*tuple_abstract)[output_idx]);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
InitInternalOutputParameter(out_node, parameters[param_index++]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// create single parameter if is a abstract real kernel
|
|
|
|
|
|
|
|
create_parameter(out_node->abstract());
|
|
|
|
|
|
|
|
InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return parameters;
|
|
|
|
return new_parameter;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
|
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
@ -770,20 +762,7 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
|
|
|
auto parameters = CreateParameterFromTuple(anf, graph);
|
|
|
|
return CreateParameterFromTuple(anf, graph);
|
|
|
|
if (parameters.empty()) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Empty parameter from cnode";
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (parameters.size() == 1) {
|
|
|
|
|
|
|
|
return parameters[0];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
|
|
|
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
|
|
|
|
|
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_input);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
|
|
|
|
|
|
|
|
return make_tuple;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) {
|
|
|
|
void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) {
|
|
|
@ -884,6 +863,7 @@ CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr
|
|
|
|
KernelGraphPtr kernel_graph = NewKernelGraph();
|
|
|
|
KernelGraphPtr kernel_graph = NewKernelGraph();
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
|
|
|
|
auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
parameter->set_abstract(cnode->abstract());
|
|
|
|
parameter->set_abstract(cnode->abstract());
|
|
|
|
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
|
|
|
|
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
|
|
|
|
auto return_node = kernel_graph->NewCNode({primitive, parameter});
|
|
|
|
auto return_node = kernel_graph->NewCNode({primitive, parameter});
|
|
|
|