|
|
|
@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
|
|
return new_value_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
|
|
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
if (!anf->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
|
|
|
|
}
|
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
|
auto valid_inputs = graph->MutableValidInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs);
|
|
|
|
|
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
|
|
|
|
graph->FrontBackendlMapAdd(anf, new_parameter);
|
|
|
|
|
graph_inputs->push_back(new_parameter);
|
|
|
|
|
valid_inputs->push_back(valid_input);
|
|
|
|
|
return new_parameter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
|
|
|
|
|
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
std::vector<AnfNodePtr> parameters;
|
|
|
|
|
std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
|
|
|
|
|
auto valid_inputs = graph->MutableValidInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs);
|
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
|
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
|
|
|
|
|
auto parameter = graph->NewParameter();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
parameter->set_abstract(abstract);
|
|
|
|
|
parameters.push_back(graph->NewParameter(parameter));
|
|
|
|
|
auto new_parameter = graph->NewParameter(parameter);
|
|
|
|
|
parameters.push_back(new_parameter);
|
|
|
|
|
valid_inputs->push_back(valid_input);
|
|
|
|
|
graph_inputs->push_back(new_parameter);
|
|
|
|
|
};
|
|
|
|
|
for (const auto &out_node : pre_graph_out) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out_node);
|
|
|
|
@ -287,18 +296,15 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelG
|
|
|
|
|
return parameters;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
|
|
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
if (!anf->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]";
|
|
|
|
|
auto parameters = CreateParameterFromTuple(anf, graph);
|
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
|
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs));
|
|
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
|
|
|
|
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
|
|
|
|
|
if (parameters.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "no parameter exist!!";
|
|
|
|
|
MS_LOG(EXCEPTION) << "No parameter exist!!";
|
|
|
|
|
}
|
|
|
|
|
if (parameters.size() == 1) {
|
|
|
|
|
return parameters[0];
|
|
|
|
@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph
|
|
|
|
|
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
|
|
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_input);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters";
|
|
|
|
|
MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
|
|
|
|
|
return make_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
|
|
|
|
|
|
|
|
|
|
GraphId SessionBasic::graph_sum_ = 0;
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
|
|
|
|
|
bool *from_other_graph,
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(from_other_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
|
|
|
|
*from_other_graph = false;
|
|
|
|
|
// get primitive of old node
|
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
// push attr to inputs[0] of new cnode
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
|
|
|
|
|
// if has multiple depends,only select first depend as parameter
|
|
|
|
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
|
|
|
|
auto anf = cnode->inputs()[input_idx];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|
|
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
|
|
|
|
continue;
|
|
|
|
|
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
|
|
|
|
|
cnode_inputs.push_back((*other_graph_cnode)[anf]);
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
|
|
|
|
|
// if input is a value node,
|
|
|
|
|
auto new_value_node = CreateNewValueNode(anf, graph);
|
|
|
|
@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<Parameter>()) {
|
|
|
|
|
// if anf is a parameter
|
|
|
|
|
cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph));
|
|
|
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
|
|
|
|
|
cnode_inputs.push_back(new_parameter);
|
|
|
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
|
|
|
|
|
graph->FrontBackendlMapAdd(anf, new_parameter);
|
|
|
|
|
} else {
|
|
|
|
|
(*other_graph_cnode)[anf] = new_parameter;
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<CNode>()) {
|
|
|
|
|
*from_other_graph = true;
|
|
|
|
|
// the input node is a cnode from other graph
|
|
|
|
|
cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph));
|
|
|
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
|
|
|
|
|
cnode_inputs.push_back(parameter_from_cnode);
|
|
|
|
|
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
|
|
|
|
|
}
|
|
|
|
|
return graph->NewCNode(cnode_inputs);
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
|
|
|
|
auto new_cnode = graph->NewCNode(cnode_inputs);
|
|
|
|
|
TraceManager::EndTrace();
|
|
|
|
|
return new_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
|
|
|
|
auto graph = std::make_shared<KernelGraph>();
|
|
|
|
|
graph->set_graph_id(graph_sum_);
|
|
|
|
|
MS_LOG(INFO) << "Create graph: " << graph_sum_;
|
|
|
|
|
size_t from_other_graph_depend_num = 0;
|
|
|
|
|
for (const auto &node : lst) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString();
|
|
|
|
|
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
|
|
|
|
// create a new cnode object
|
|
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
|
|
|
|
bool from_other_graph = false;
|
|
|
|
|
// only first depend from other graph can create
|
|
|
|
|
bool valid_input = true;
|
|
|
|
|
if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
|
|
|
|
valid_input = false;
|
|
|
|
|
}
|
|
|
|
|
auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
|
|
|
|
|
from_other_graph_depend_num++;
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode);
|
|
|
|
|
new_cnode->set_abstract(cnode->abstract());
|
|
|
|
|
new_cnode->set_scope(cnode->scope());
|
|
|
|
|
// record map relations between anf from ME and new anf node used in backend
|
|
|
|
|
graph->FrontBackendlMapAdd(node, new_cnode);
|
|
|
|
|
TraceManager::EndTrace();
|
|
|
|
|
}
|
|
|
|
|
// add a make_tuple at the end of graph as output
|
|
|
|
|
graph->set_output(ConstructOutput(outputs, graph));
|
|
|
|
@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor:
|
|
|
|
|
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
std::vector<AnfNodePtr> output_args;
|
|
|
|
|
auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr {
|
|
|
|
|
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
|
|
|
|
|
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
|
|
|
|
|
if (backend_anf != nullptr) {
|
|
|
|
|
return backend_anf;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!";
|
|
|
|
|
for (const auto &output : outputs) {
|
|
|
|
|
MS_LOG(INFO) << "output:" << output->DebugString();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
|
|
|
|
|
};
|
|
|
|
|
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
(void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
|
|
|
|
|