From 09d5a7227e0b0ac5ec907597b9d03b5859bf0588 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Mon, 15 Jun 2020 14:24:20 +0800 Subject: [PATCH] fix summary nodes memory reuse refcount --- .../ccsrc/pre_activate/mem_reuse/mem_reuse.cc | 26 +++++++++++++++++++ .../ccsrc/pre_activate/mem_reuse/mem_reuse.h | 1 + mindspore/ccsrc/session/kernel_graph.h | 2 +- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index ca4cbf5158..61d79b56d5 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -298,6 +298,31 @@ void MemReuseUtil::SetReuseRefCount() { } } +void MemReuseUtil::SetSummaryNodesRefCount() { + bool summary_exist = graph_->summary_node_exist(); + if (!summary_exist) { + return; + } + + auto summary_nodes = graph_->summary_nodes(); + if (summary_nodes.empty()) { + return; + } + + for (auto &node_item : summary_nodes) { + auto node = node_item.second.first; + size_t index = IntToSize(node_item.second.second); + MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; + if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) { + KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; + kernel_ref->ref_count_ = kMaxRefCount; + kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; + } else { + MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope(); + } + } +} + void MemReuseUtil::SetGraphOutputRefCount() { auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { @@ -336,6 +361,7 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) { } SetKernelDefMap(); SetReuseRefCount(); + SetSummaryNodesRefCount(); SetWorkSpaceList(); #ifdef MEM_REUSE_DEBUG MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h index 08029f231a..c7a129f1e9 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h @@ -63,6 +63,7 @@ class MemReuseUtil { void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); void SetKernelDefInputs(); void SetReuseRefCount(); + void SetSummaryNodesRefCount(); // Set the reference count of graph output specially. void SetGraphOutputRefCount(); // Reset the dynamic used reference count by ref_count_. diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 1a3817bb11..dbb79d561c 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -142,7 +142,7 @@ class KernelGraph : public FuncGraph { bool get_output_null() { return null_output_; } void set_output_null(bool is_output_null) { null_output_ = is_output_null; } void PrintGraphExecuteOrder() const; - std::map> &summary_nodes() { return summary_nodes_; } + const std::map> &summary_nodes() const { return summary_nodes_; } void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } private: