|
|
|
@ -29,6 +29,7 @@
|
|
|
|
|
#include "device/ascend/ascend_kernel_runtime.h"
|
|
|
|
|
#include "device/ascend/ascend_device_address.h"
|
|
|
|
|
#include "pre_activate/ascend/ascend_backend_optimization.h"
|
|
|
|
|
#include "pre_activate/common/common_backend_optimization.h"
|
|
|
|
|
#include "device/kernel_adjust.h"
|
|
|
|
|
#include "device/ascend/ascend_stream_assign.h"
|
|
|
|
|
#include "device/ascend/ascend_label_assign.h"
|
|
|
|
@ -283,36 +284,38 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
|
|
|
|
|
|
|
|
|
|
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|
|
|
|
MS_LOG(INFO) << "start";
|
|
|
|
|
auto graph = ConstructKernelGraph(func_graph);
|
|
|
|
|
std::vector<KernelGraphPtr> all_graphs;
|
|
|
|
|
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
|
|
|
|
BackendOptimization(all_graphs);
|
|
|
|
|
// split switch
|
|
|
|
|
SplitGraphs(NOT_NULL(graph));
|
|
|
|
|
SplitGraphs(NOT_NULL(root_graph));
|
|
|
|
|
// insert goto labels and label_sets
|
|
|
|
|
LinkChildGraphs(NOT_NULL(graph));
|
|
|
|
|
LinkChildGraphs(NOT_NULL(root_graph));
|
|
|
|
|
// resource initialize
|
|
|
|
|
InitRuntimeResource();
|
|
|
|
|
// assign label
|
|
|
|
|
AssignLabel(NOT_NULL(graph));
|
|
|
|
|
// recurse compile child graph
|
|
|
|
|
AssignLabel(NOT_NULL(root_graph));
|
|
|
|
|
// recurse compile child root_graph
|
|
|
|
|
std::set<KernelGraphPtr> memo;
|
|
|
|
|
RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
|
|
|
|
|
// root graph valiate,include genearte execute order and so on
|
|
|
|
|
RootGraphExecutorValidate(NOT_NULL(graph));
|
|
|
|
|
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
|
|
|
|
|
// root root_graph valiate,include genearte execute order and so on
|
|
|
|
|
RootGraphExecutorValidate(NOT_NULL(root_graph));
|
|
|
|
|
// adjust kernel
|
|
|
|
|
AdjustKernel(graph);
|
|
|
|
|
AdjustKernel(root_graph);
|
|
|
|
|
// assign stream
|
|
|
|
|
AssignStream(graph);
|
|
|
|
|
AssignStream(root_graph);
|
|
|
|
|
// insert profiling point
|
|
|
|
|
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
|
|
|
|
|
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
|
|
|
|
|
// build kernel
|
|
|
|
|
BuildKernel(graph);
|
|
|
|
|
BuildKernel(root_graph);
|
|
|
|
|
// alloc mem
|
|
|
|
|
MemoryAlloc(graph.get());
|
|
|
|
|
MemoryAlloc(root_graph.get());
|
|
|
|
|
// task generate
|
|
|
|
|
GenerateTaskInfo(graph);
|
|
|
|
|
GenerateTaskInfo(root_graph);
|
|
|
|
|
// load task into device
|
|
|
|
|
LoadTask(graph);
|
|
|
|
|
// return the graph id to backend
|
|
|
|
|
auto graph_id = graph->graph_id();
|
|
|
|
|
LoadTask(root_graph);
|
|
|
|
|
// return the root_graph id to backend
|
|
|
|
|
auto graph_id = root_graph->graph_id();
|
|
|
|
|
return graph_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1569,6 +1572,14 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|
|
|
|
return call_node_inputs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
|
|
|
|
|
MS_LOG(INFO) << "Start BackendCommonOptimization";
|
|
|
|
|
for (auto &graph : all_graphs) {
|
|
|
|
|
opt::BackendCommonOptimization(graph);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "End.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
std::set<KernelGraphPtr> memo;
|
|
|
|
|
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
|
|
|
|
|