|
|
@ -322,6 +322,18 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|
|
|
return graph_id;
|
|
|
|
return graph_id;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|
|
|
|
|
|
|
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
|
|
|
|
|
|
|
|
for (auto graph_id : graph_order) {
|
|
|
|
|
|
|
|
auto child_graph = GetGraph(graph_id);
|
|
|
|
|
|
|
|
if (child_graph->summary_node_exist()) {
|
|
|
|
|
|
|
|
kernel_graph->set_summary_node_exist(true);
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
kernel_graph->set_summary_node_exist(false);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AscendSession::BuildGraph(GraphId graph_id) {
|
|
|
|
void AscendSession::BuildGraph(GraphId graph_id) {
|
|
|
|
MS_LOG(INFO) << "start";
|
|
|
|
MS_LOG(INFO) << "start";
|
|
|
|
auto graph = GetGraph(graph_id);
|
|
|
|
auto graph = GetGraph(graph_id);
|
|
|
@ -337,6 +349,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|
|
|
InsertAllAssigns();
|
|
|
|
InsertAllAssigns();
|
|
|
|
// insert switch and active to child graph
|
|
|
|
// insert switch and active to child graph
|
|
|
|
MergeSwitchCompile();
|
|
|
|
MergeSwitchCompile();
|
|
|
|
|
|
|
|
SetFinalGraphSummaryFlag(graph);
|
|
|
|
// OptChildGraphs
|
|
|
|
// OptChildGraphs
|
|
|
|
auto graph_order = GetGraphOrder(final_graph_id_);
|
|
|
|
auto graph_order = GetGraphOrder(final_graph_id_);
|
|
|
|
auto &graph_type = GetGraphOrderType(final_graph_id_);
|
|
|
|
auto &graph_type = GetGraphOrderType(final_graph_id_);
|
|
|
@ -348,6 +361,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|
|
|
auto child_graph = GetGraph(graph_order[i]);
|
|
|
|
auto child_graph = GetGraph(graph_order[i]);
|
|
|
|
CompileChildGraph(child_graph);
|
|
|
|
CompileChildGraph(child_graph);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
GetSummaryNodes(graph.get());
|
|
|
|
// merge child graph
|
|
|
|
// merge child graph
|
|
|
|
MergeGraphExecOrder();
|
|
|
|
MergeGraphExecOrder();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -751,25 +765,26 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
|
|
|
|
return final_graph_id_;
|
|
|
|
return final_graph_id_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AscendSession::GetSummaryNodes(const KernelGraph *graph,
|
|
|
|
void AscendSession::GetSummaryNodes(KernelGraph *graph) {
|
|
|
|
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Update summary Start";
|
|
|
|
MS_LOG(DEBUG) << "Update summary Start";
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(summary);
|
|
|
|
|
|
|
|
summary->clear();
|
|
|
|
|
|
|
|
// if final graph have no child graph
|
|
|
|
// if final graph have no child graph
|
|
|
|
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
|
|
|
|
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
|
|
|
|
if (graph_order_iter == graph_execute_orders_.end()) {
|
|
|
|
if (graph_order_iter == graph_execute_orders_.end()) {
|
|
|
|
SessionBasic::GetSummaryNodes(graph, summary);
|
|
|
|
SessionBasic::GetSummaryNodes(graph);
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// for every child graph, find summary nodes
|
|
|
|
// for every child graph, find summary nodes
|
|
|
|
|
|
|
|
auto summary = graph->summary_nodes();
|
|
|
|
auto graph_order = GetGraphOrder(graph->graph_id());
|
|
|
|
auto graph_order = GetGraphOrder(graph->graph_id());
|
|
|
|
for (size_t i = 0; i < graph_order.size(); i++) {
|
|
|
|
for (size_t i = 0; i < graph_order.size(); i++) {
|
|
|
|
auto child_graph = GetGraph(graph_order[i]);
|
|
|
|
auto child_graph = GetGraph(graph_order[i]);
|
|
|
|
SessionBasic::GetSummaryNodes(child_graph.get(), summary);
|
|
|
|
SessionBasic::GetSummaryNodes(child_graph.get());
|
|
|
|
|
|
|
|
auto child_graph_summary = child_graph->summary_nodes();
|
|
|
|
|
|
|
|
summary.insert(child_graph_summary.begin(), child_graph_summary.end());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
|
|
|
graph->set_summary_nodes(summary);
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
|
|
|
|
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
|
|
|
|