!2058 bind summary nodes to KernelGraph in order to memory reuse

Merge pull request !2058 from Margaret_wangrui/summary
pull/2058/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 6c26431ab3

@ -322,6 +322,18 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
return graph_id; return graph_id;
} }
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
for (auto graph_id : graph_order) {
auto child_graph = GetGraph(graph_id);
if (child_graph->summary_node_exist()) {
kernel_graph->set_summary_node_exist(true);
return;
}
}
kernel_graph->set_summary_node_exist(false);
}
void AscendSession::BuildGraph(GraphId graph_id) { void AscendSession::BuildGraph(GraphId graph_id) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "start";
auto graph = GetGraph(graph_id); auto graph = GetGraph(graph_id);
@ -337,6 +349,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
InsertAllAssigns(); InsertAllAssigns();
// insert switch and active to child graph // insert switch and active to child graph
MergeSwitchCompile(); MergeSwitchCompile();
SetFinalGraphSummaryFlag(graph);
// OptChildGraphs // OptChildGraphs
auto graph_order = GetGraphOrder(final_graph_id_); auto graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_); auto &graph_type = GetGraphOrderType(final_graph_id_);
@ -348,6 +361,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
auto child_graph = GetGraph(graph_order[i]); auto child_graph = GetGraph(graph_order[i]);
CompileChildGraph(child_graph); CompileChildGraph(child_graph);
} }
GetSummaryNodes(graph.get());
// merge child graph // merge child graph
MergeGraphExecOrder(); MergeGraphExecOrder();
} else { } else {
@ -751,25 +765,26 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
return final_graph_id_; return final_graph_id_;
} }
void AscendSession::GetSummaryNodes(const KernelGraph *graph, void AscendSession::GetSummaryNodes(KernelGraph *graph) {
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_LOG(DEBUG) << "Update summary Start"; MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
summary->clear();
// if final graph have no child graph // if final graph have no child graph
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
if (graph_order_iter == graph_execute_orders_.end()) { if (graph_order_iter == graph_execute_orders_.end()) {
SessionBasic::GetSummaryNodes(graph, summary); SessionBasic::GetSummaryNodes(graph);
return; return;
} }
// for every child graph, find summary nodes // for every child graph, find summary nodes
auto summary = graph->summary_nodes();
auto graph_order = GetGraphOrder(graph->graph_id()); auto graph_order = GetGraphOrder(graph->graph_id());
for (size_t i = 0; i < graph_order.size(); i++) { for (size_t i = 0; i < graph_order.size(); i++) {
auto child_graph = GetGraph(graph_order[i]); auto child_graph = GetGraph(graph_order[i]);
SessionBasic::GetSummaryNodes(child_graph.get(), summary); SessionBasic::GetSummaryNodes(child_graph.get());
auto child_graph_summary = child_graph->summary_nodes();
summary.insert(child_graph_summary.begin(), child_graph_summary.end());
} }
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); graph->set_summary_nodes(summary);
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
} }
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {

@ -67,8 +67,7 @@ class AscendSession : public SessionBasic {
void SetActive(GraphId, GraphId) override; void SetActive(GraphId, GraphId) override;
// compile child graph when session have multiple child graphs // compile child graph when session have multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph); void CompileChildGraph(const KernelGraphPtr &child_graph);
void GetSummaryNodes(const KernelGraph *graph, void GetSummaryNodes(KernelGraph *graph) override;
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) override;
private: private:
void InitRuntimeResource(); void InitRuntimeResource();
@ -149,6 +148,7 @@ class AscendSession : public SessionBasic {
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
// sync intial tensors' data to device // sync intial tensors' data to device
void SyncInitialTenosrToDevice(); void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
// member variables // member variables
// key is final_graph_id,value is child graph execute order of final graph // key is final_graph_id,value is child graph execute order of final graph

@ -73,7 +73,8 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
kernel_graph->set_execution_order(execution_order); kernel_graph->set_execution_order(execution_order);
NamedSummaryOutputs summary_outputs; NamedSummaryOutputs summary_outputs;
if (enable_summary) { if (enable_summary) {
GetSummaryNodes(kernel_graph.get(), &summary_outputs); GetSummaryNodes(kernel_graph.get());
summary_outputs = kernel_graph->summary_nodes();
runtime_.IncreaseSummaryRefCount(summary_outputs); runtime_.IncreaseSummaryRefCount(summary_outputs);
} }

@ -142,6 +142,8 @@ class KernelGraph : public FuncGraph {
bool get_output_null() { return null_output_; } bool get_output_null() { return null_output_; }
void set_output_null(bool is_output_null) { null_output_ = is_output_null; } void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
void PrintGraphExecuteOrder() const; void PrintGraphExecuteOrder() const;
std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() { return summary_nodes_; }
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
private: private:
// remove value node form graph // remove value node form graph
@ -175,6 +177,7 @@ class KernelGraph : public FuncGraph {
// record map between ref final output anf with index and ref origin input with index // record map between ref final output anf with index and ref origin input with index
std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_; std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
// graph needn't execute // graph needn't execute
bool executable_; bool executable_;
// exist summary node in graph // exist summary node in graph

@ -745,13 +745,13 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
} }
void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary Start"; MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
if (!graph->summary_node_exist()) { if (!graph->summary_node_exist()) {
return; return;
} }
auto summary = graph->summary_nodes();
auto apply_list = TopoSort(graph->get_return()); auto apply_list = TopoSort(graph->get_return());
for (auto &n : apply_list) { for (auto &n : apply_list) {
MS_EXCEPTION_IF_NULL(n); MS_EXCEPTION_IF_NULL(n);
@ -764,14 +764,15 @@ void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs
} }
auto node = cnode->input(kSummaryGetItem); auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) { if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
} }
(*summary)[n->fullname_with_scope()] = item_with_index; summary[n->fullname_with_scope()] = item_with_index;
} }
} }
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); graph->set_summary_nodes(summary);
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
} }
void SessionBasic::Summary(KernelGraph *graph) { void SessionBasic::Summary(KernelGraph *graph) {
@ -779,8 +780,8 @@ void SessionBasic::Summary(KernelGraph *graph) {
return; return;
} }
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
NamedSummaryOutputs summary_outputs; GetSummaryNodes(graph);
GetSummaryNodes(graph, &summary_outputs); auto summary_outputs = graph->summary_nodes();
// do not exist summary node // do not exist summary node
if (summary_outputs.empty()) { if (summary_outputs.empty()) {
return; return;

@ -93,8 +93,7 @@ class SessionBasic {
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
virtual void SetActive(GraphId, GraphId) {} virtual void SetActive(GraphId, GraphId) {}
virtual void GetSummaryNodes(const KernelGraph *graph, virtual void GetSummaryNodes(KernelGraph *graph);
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary);
protected: protected:
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
@ -130,7 +129,7 @@ class SessionBasic {
}; };
using SessionPtr = std::shared_ptr<session::SessionBasic>; using SessionPtr = std::shared_ptr<session::SessionBasic>;
using NamedSummaryOutputs = std::unordered_map<std::string, std::pair<AnfNodePtr, int>>; using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H

Loading…
Cancel
Save