From d238682d6826f17b174839c261e362442c6c84d8 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Thu, 11 Jun 2020 14:36:37 +0800 Subject: [PATCH] bind summary nodes to graph --- mindspore/ccsrc/session/ascend_session.cc | 29 +++++++++++++++++------ mindspore/ccsrc/session/ascend_session.h | 4 ++-- mindspore/ccsrc/session/cpu_session.cc | 3 ++- mindspore/ccsrc/session/kernel_graph.h | 3 +++ mindspore/ccsrc/session/session_basic.cc | 15 ++++++------ mindspore/ccsrc/session/session_basic.h | 5 ++-- 6 files changed, 39 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 6ba7ad4256..bf1fa87530 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -322,6 +322,18 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { return graph_id; } +void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph) { + auto graph_order = GetGraphOrder(kernel_graph->graph_id()); + for (auto graph_id : graph_order) { + auto child_graph = GetGraph(graph_id); + if (child_graph->summary_node_exist()) { + kernel_graph->set_summary_node_exist(true); + return; + } + } + kernel_graph->set_summary_node_exist(false); +} + void AscendSession::BuildGraph(GraphId graph_id) { MS_LOG(INFO) << "start"; auto graph = GetGraph(graph_id); @@ -337,6 +349,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { InsertAllAssigns(); // insert switch and active to child graph MergeSwitchCompile(); + SetFinalGraphSummaryFlag(graph); // OptChildGraphs auto graph_order = GetGraphOrder(final_graph_id_); auto &graph_type = GetGraphOrderType(final_graph_id_); @@ -348,6 +361,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { auto child_graph = GetGraph(graph_order[i]); CompileChildGraph(child_graph); } + GetSummaryNodes(graph.get()); // merge child graph MergeGraphExecOrder(); } else { @@ -751,25 +765,26 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { return final_graph_id_; } -void AscendSession::GetSummaryNodes(const KernelGraph *graph, - std::unordered_map> *summary) { +void AscendSession::GetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(summary); - summary->clear(); // if final graph have no child graph auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); if (graph_order_iter == graph_execute_orders_.end()) { - SessionBasic::GetSummaryNodes(graph, summary); + SessionBasic::GetSummaryNodes(graph); return; } // for every child graph, find summary nodes + auto summary = graph->summary_nodes(); auto graph_order = GetGraphOrder(graph->graph_id()); for (size_t i = 0; i < graph_order.size(); i++) { auto child_graph = GetGraph(graph_order[i]); - SessionBasic::GetSummaryNodes(child_graph.get(), summary); + SessionBasic::GetSummaryNodes(child_graph.get()); + auto child_graph_summary = child_graph->summary_nodes(); + summary.insert(child_graph_summary.begin(), child_graph_summary.end()); } - MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); + graph->set_summary_nodes(summary); + MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); } AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 0c31c4c77e..13ee80b254 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -67,8 +67,7 @@ class AscendSession : public SessionBasic { void SetActive(GraphId, GraphId) override; // compile child graph when session have multiple child graphs void CompileChildGraph(const KernelGraphPtr &child_graph); - void GetSummaryNodes(const KernelGraph *graph, - std::unordered_map> *summary) override; + void GetSummaryNodes(KernelGraph *graph) override; private: void InitRuntimeResource(); @@ -149,6 +148,7 @@ class AscendSession : public SessionBasic { AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); // sync intial tensors' data to device void SyncInitialTenosrToDevice(); + void SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph); // member variables // key is final_graph_id,value is child graph execute order of final graph diff --git a/mindspore/ccsrc/session/cpu_session.cc b/mindspore/ccsrc/session/cpu_session.cc index c3caf512ac..8d6bc0f2b9 100644 --- a/mindspore/ccsrc/session/cpu_session.cc +++ b/mindspore/ccsrc/session/cpu_session.cc @@ -73,7 +73,8 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vectorset_execution_order(execution_order); NamedSummaryOutputs summary_outputs; if (enable_summary) { - GetSummaryNodes(kernel_graph.get(), &summary_outputs); + GetSummaryNodes(kernel_graph.get()); + summary_outputs = kernel_graph->summary_nodes(); runtime_.IncreaseSummaryRefCount(summary_outputs); } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 2cd7a2340a..6653de7e55 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -142,6 +142,8 @@ class KernelGraph : public FuncGraph { bool get_output_null() { return null_output_; } void set_output_null(bool is_output_null) { null_output_ = is_output_null; } void PrintGraphExecuteOrder() const; + std::map> &summary_nodes() { return summary_nodes_; } + void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } private: // remove value node form graph @@ -175,6 +177,7 @@ class KernelGraph : public FuncGraph { // record map between ref final output anf with index and ref origin input with index std::map ref_out_in_map_; std::unordered_map>> node_output_edges_; + std::map> summary_nodes_; // graph needn't execute bool executable_; // exist summary node in graph diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index d11446a8ba..8124a0f3c8 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -745,13 +745,13 @@ void SessionBasic::Reorder(std::vector *node_list) { (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); } -void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { +void SessionBasic::GetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(summary); if (!graph->summary_node_exist()) { return; } + auto summary = graph->summary_nodes(); auto apply_list = TopoSort(graph->get_return()); for (auto &n : apply_list) { MS_EXCEPTION_IF_NULL(n); @@ -764,14 +764,15 @@ void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs } auto node = cnode->input(kSummaryGetItem); MS_EXCEPTION_IF_NULL(node); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); if (!AnfAlgo::IsRealKernel(item_with_index.first)) { MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); } - (*summary)[n->fullname_with_scope()] = item_with_index; + summary[n->fullname_with_scope()] = item_with_index; } } - MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); + graph->set_summary_nodes(summary); + MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); } void SessionBasic::Summary(KernelGraph *graph) { @@ -779,8 +780,8 @@ void SessionBasic::Summary(KernelGraph *graph) { return; } MS_EXCEPTION_IF_NULL(graph); - NamedSummaryOutputs summary_outputs; - GetSummaryNodes(graph, &summary_outputs); + GetSummaryNodes(graph); + auto summary_outputs = graph->summary_nodes(); // do not exist summary node if (summary_outputs.empty()) { return; diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index 4620bc763d..f4f391d0fe 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -93,8 +93,7 @@ class SessionBasic { virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } virtual void SetActive(GraphId, GraphId) {} - virtual void GetSummaryNodes(const KernelGraph *graph, - std::unordered_map> *summary); + virtual void GetSummaryNodes(KernelGraph *graph); protected: virtual void LoadInputData(const std::shared_ptr &kernel_graph, @@ -130,7 +129,7 @@ class SessionBasic { }; using SessionPtr = std::shared_ptr; -using NamedSummaryOutputs = std::unordered_map>; +using NamedSummaryOutputs = std::map>; } // namespace session } // namespace mindspore #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H