|
|
|
@ -451,6 +451,126 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|
|
|
|
return new_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs;
|
|
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_input);
|
|
|
|
|
if (IsValueNode<FuncGraph>(attr_input)) {
|
|
|
|
|
// create primitive of cnode:call
|
|
|
|
|
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
|
|
|
|
|
// create a ValueNode<KernelGraph> as input of cnode:call
|
|
|
|
|
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
|
|
|
|
|
} else {
|
|
|
|
|
auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
|
|
|
|
|
if (new_value_node != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(new_value_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (attr_input->isa<CNode>()) {
|
|
|
|
|
// create primitive of cnode:call(switch)
|
|
|
|
|
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
|
|
|
|
|
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
|
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
|
|
|
|
auto prim = GetCNodePrimitive(cnode_input);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim->name() != kSwitchOpName) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
|
|
|
|
|
}
|
|
|
|
|
cnode_inputs.emplace_back(cnode_input);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
|
|
|
|
|
<< ", but input[0] has not been created.";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// get primitive of old node
|
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
// push attr to inputs[0] of new cnode
|
|
|
|
|
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
|
|
|
|
auto anf = cnode->inputs()[input_idx];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
// anf has been created before
|
|
|
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<ValueNode>()) {
|
|
|
|
|
if (!IsValueNode<FuncGraph>(anf)) {
|
|
|
|
|
// if input is a common value node,
|
|
|
|
|
auto new_value_node = CreateNewValueNode(anf, graph);
|
|
|
|
|
if (new_value_node != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(new_value_node);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// if input is a ValueNode<FuncGraph>
|
|
|
|
|
auto new_value_node = CreateValueNodeKernelGraph(anf, graph);
|
|
|
|
|
if (new_value_node != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(new_value_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<Parameter>()) {
|
|
|
|
|
auto new_parameter = CreateNewParameter(anf, graph);
|
|
|
|
|
cnode_inputs.push_back(new_parameter);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
|
|
|
|
|
}
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
|
|
|
|
auto new_cnode = graph->NewCNode(cnode_inputs);
|
|
|
|
|
TraceManager::EndTrace();
|
|
|
|
|
return new_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
auto value_node = anf->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node);
|
|
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sub_func_graph);
|
|
|
|
|
if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
|
|
|
|
|
}
|
|
|
|
|
auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
|
|
|
|
|
|
|
|
|
|
ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
|
|
|
|
|
new_value_node->set_abstract(value_node->abstract());
|
|
|
|
|
// create new kernel_info of new value_node
|
|
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
|
|
|
kernel_info->SetFeatureMapFlag(false);
|
|
|
|
|
new_value_node->set_kernel_info(kernel_info);
|
|
|
|
|
// create kernel_build_info for new value node
|
|
|
|
|
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
|
|
|
|
AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
|
|
|
|
|
|
|
|
|
|
graph->FrontBackendlMapAdd(anf, new_value_node);
|
|
|
|
|
graph->AddValueNodeToGraph(new_value_node);
|
|
|
|
|
|
|
|
|
|
return new_value_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
if (!anf->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
|
|
|
|
|
|
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
|
|
|
|
graph_inputs->push_back(new_parameter);
|
|
|
|
|
graph->FrontBackendlMapAdd(anf, new_parameter);
|
|
|
|
|
|
|
|
|
|
return new_parameter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
|
|
|
|
auto graph = NewKernelGraph();
|
|
|
|
@ -494,7 +614,69 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; }
|
|
|
|
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
|
|
|
|
|
MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
|
|
|
|
|
return front_backend_graph_map_[func_graph];
|
|
|
|
|
}
|
|
|
|
|
auto node_list = TopoSort(func_graph->get_return());
|
|
|
|
|
auto graph = NewKernelGraph();
|
|
|
|
|
front_backend_graph_map_[func_graph] = graph;
|
|
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
|
|
|
|
|
|
|
|
|
for (const auto &node : node_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode";
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
|
|
|
|
|
// recurse control ops: call, partial
|
|
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_input);
|
|
|
|
|
if (IsValueNode<FuncGraph>(attr_input)) {
|
|
|
|
|
// recurse call subgraph
|
|
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input);
|
|
|
|
|
ConstructKernelGraph(sub_func_graph);
|
|
|
|
|
} else if (IsValueNode<Primitive>(attr_input)) {
|
|
|
|
|
auto prim = GetCNodePrimitive(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim->name() == kPartialOpName) {
|
|
|
|
|
// recurse partial subgraph
|
|
|
|
|
auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_node);
|
|
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
|
|
|
|
|
ConstructKernelGraph(sub_func_graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// create a new cnode object
|
|
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode);
|
|
|
|
|
new_cnode->set_abstract(cnode->abstract());
|
|
|
|
|
new_cnode->set_scope(cnode->scope());
|
|
|
|
|
graph->FrontBackendlMapAdd(node, new_cnode);
|
|
|
|
|
|
|
|
|
|
// set original return to kernel_graph
|
|
|
|
|
if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) {
|
|
|
|
|
graph->set_return(new_cnode);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_);
|
|
|
|
|
FuncGraphManagerPtr manager = context_->manager();
|
|
|
|
|
if (manager) {
|
|
|
|
|
manager->AddFuncGraph(graph);
|
|
|
|
|
graph->set_manager(manager);
|
|
|
|
|
}
|
|
|
|
|
graph->SetExecOrderByDefault();
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// run graph steps
|
|
|
|
|
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|
|
|
|