fix summary nodes memory reuse refcount

pull/2076/head
laiyongqiang 5 years ago
parent 6c26431ab3
commit 09d5a7227e

@ -298,6 +298,31 @@ void MemReuseUtil::SetReuseRefCount() {
} }
} }
void MemReuseUtil::SetSummaryNodesRefCount() {
bool summary_exist = graph_->summary_node_exist();
if (!summary_exist) {
return;
}
auto summary_nodes = graph_->summary_nodes();
if (summary_nodes.empty()) {
return;
}
for (auto &node_item : summary_nodes) {
auto node = node_item.second.first;
size_t index = IntToSize(node_item.second.second);
MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) {
KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
kernel_ref->ref_count_ = kMaxRefCount;
kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
} else {
MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope();
}
}
}
void MemReuseUtil::SetGraphOutputRefCount() { void MemReuseUtil::SetGraphOutputRefCount() {
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
for (const auto &node : nodes) { for (const auto &node : nodes) {
@ -336,6 +361,7 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
} }
SetKernelDefMap(); SetKernelDefMap();
SetReuseRefCount(); SetReuseRefCount();
SetSummaryNodesRefCount();
SetWorkSpaceList(); SetWorkSpaceList();
#ifdef MEM_REUSE_DEBUG #ifdef MEM_REUSE_DEBUG
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);

@ -63,6 +63,7 @@ class MemReuseUtil {
void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
void SetKernelDefInputs(); void SetKernelDefInputs();
void SetReuseRefCount(); void SetReuseRefCount();
void SetSummaryNodesRefCount();
// Set the reference count of graph output specially. // Set the reference count of graph output specially.
void SetGraphOutputRefCount(); void SetGraphOutputRefCount();
// Reset the dynamic used reference count by ref_count_. // Reset the dynamic used reference count by ref_count_.

@ -142,7 +142,7 @@ class KernelGraph : public FuncGraph {
bool get_output_null() { return null_output_; } bool get_output_null() { return null_output_; }
void set_output_null(bool is_output_null) { null_output_ = is_output_null; } void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
void PrintGraphExecuteOrder() const; void PrintGraphExecuteOrder() const;
std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() { return summary_nodes_; } const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; } void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
private: private:

Loading…
Cancel
Save