|
|
|
@ -20,25 +20,22 @@
|
|
|
|
|
|
|
|
|
|
#include "c_ops/primitive_c.h"
|
|
|
|
|
#include "ir/manager.h"
|
|
|
|
|
#include "ir/param_info.h"
|
|
|
|
|
#include "backend/kernel_compiler/common_utils.h"
|
|
|
|
|
#include "base/core_ops.h"
|
|
|
|
|
#include "common/trans.h"
|
|
|
|
|
#include "utils/config_manager.h"
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "backend/session/executor.h"
|
|
|
|
|
#include "backend/session/executor_manager.h"
|
|
|
|
|
#include "backend/optimizer/common/common_backend_optimization.h"
|
|
|
|
|
#include "backend/optimizer/common/helper.h"
|
|
|
|
|
#include "runtime/device/kernel_runtime_manager.h"
|
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
#include "ir/dtype.h"
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "debug/anf_ir_dump.h"
|
|
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
|
|
|
|
#include "ps/worker.h"
|
|
|
|
|
#include "ps/common.h"
|
|
|
|
|
#include "ps/util.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -665,8 +662,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_input);
|
|
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
|
|
|
|
if (cnode_input == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
|
|
|
|
|
<< ", but input[0] has not been created.";
|
|
|
|
|
MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
// if the node is partial, insert the inputs of partial to the call
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
|
|
|
|
@ -682,7 +679,9 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
|
|
|
|
|
return CreateCallSwitchInputs(cnode, graph);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
|
|
|
|
|
MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
|
|
|
|
|
<< "must be partial or switch.";
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
@ -752,6 +751,10 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
|
|
|
|
|
// 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
|
|
|
|
|
// 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created
|
|
|
|
|
cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
|
|
|
|
|
if (cnode_inputs.empty()) {
|
|
|
|
|
MS_LOG_ERROR << "Create switch or partial failed, cnode:" << cnode->DebugString();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// get primitive of old node
|
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
@ -877,14 +880,16 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
|
|
|
|
|
bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
// create a new cnode object
|
|
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode);
|
|
|
|
|
auto new_cnode = CreateNewCNode(cnode, graph);
|
|
|
|
|
if (new_cnode == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
new_cnode->set_abstract(cnode->abstract());
|
|
|
|
|
std::string fullname;
|
|
|
|
|
if (cnode->input(kAnfPrimitiveIndex)->isa<CNode>()) {
|
|
|
|
@ -898,6 +903,7 @@ void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
|
|
|
|
|
graph->set_return(new_cnode);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
|
|
|
@ -909,11 +915,10 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
front_backend_graph_map_[func_graph] = graph;
|
|
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
|
|
|
|
|
|
|
|
|
bool is_trace_back = false;
|
|
|
|
|
for (const auto &node : node_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
|
|
|
|
// Create parameter
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
auto graph_inputs = graph->MutableInputs();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
@ -921,25 +926,28 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|
|
|
|
graph_inputs->push_back(new_parameter);
|
|
|
|
|
graph->FrontBackendlMapAdd(node, new_parameter);
|
|
|
|
|
continue;
|
|
|
|
|
} else if (node->isa<ValueNode>()) {
|
|
|
|
|
}
|
|
|
|
|
// Create value node
|
|
|
|
|
if (node->isa<ValueNode>()) {
|
|
|
|
|
// Create common value node
|
|
|
|
|
if (!IsValueNode<FuncGraph>(node)) {
|
|
|
|
|
// if input is a common value node,
|
|
|
|
|
(void)CreateNewValueNode(node, graph.get());
|
|
|
|
|
} else {
|
|
|
|
|
// if input is a ValueNode<FuncGraph>
|
|
|
|
|
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
|
|
|
|
|
if (front_backend_graph_map_.find(child_graph) == front_backend_graph_map_.end()) {
|
|
|
|
|
(void)ConstructKernelGraph(child_graph, all_out_graph);
|
|
|
|
|
}
|
|
|
|
|
(void)CreateValueNodeKernelGraph(node, graph.get());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// Create child kernel graph according ValueNode<FuncGraph>
|
|
|
|
|
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
|
|
|
|
|
if (front_backend_graph_map_.find(child_graph) == front_backend_graph_map_.end()) {
|
|
|
|
|
(void)ConstructKernelGraph(child_graph, all_out_graph);
|
|
|
|
|
}
|
|
|
|
|
(void)CreateValueNodeKernelGraph(node, graph.get());
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
CreateCNodeKernelGraph(node, graph);
|
|
|
|
|
}
|
|
|
|
|
// Create cnode
|
|
|
|
|
if (!CreateCNodeOfKernelGraph(node, graph.get())) {
|
|
|
|
|
DumpIR("contruct_kernel_graph_fail.ir", func_graph);
|
|
|
|
|
MS_LOG_EXCEPTION << "construct func graph " << func_graph->ToString() << "fail!";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
|
|
|
|
|
graph->set_output_null(is_trace_back);
|
|
|
|
|
AddParameterToGraphInputs(func_graph->parameters(), graph.get());
|
|
|
|
|
graph->SetExecOrderByDefault();
|
|
|
|
|
if (ExistSummaryNode(graph.get())) {
|
|
|
|
|