|
|
@ -298,6 +298,31 @@ void MemReuseUtil::SetReuseRefCount() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void MemReuseUtil::SetSummaryNodesRefCount() {
|
|
|
|
|
|
|
|
bool summary_exist = graph_->summary_node_exist();
|
|
|
|
|
|
|
|
if (!summary_exist) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto summary_nodes = graph_->summary_nodes();
|
|
|
|
|
|
|
|
if (summary_nodes.empty()) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &node_item : summary_nodes) {
|
|
|
|
|
|
|
|
auto node = node_item.second.first;
|
|
|
|
|
|
|
|
size_t index = IntToSize(node_item.second.second);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
|
|
|
|
|
|
|
|
if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) {
|
|
|
|
|
|
|
|
KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
|
|
|
|
|
|
|
|
kernel_ref->ref_count_ = kMaxRefCount;
|
|
|
|
|
|
|
|
kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MemReuseUtil::SetGraphOutputRefCount() {
|
|
|
|
void MemReuseUtil::SetGraphOutputRefCount() {
|
|
|
|
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
|
|
|
|
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
|
|
|
|
for (const auto &node : nodes) {
|
|
|
|
for (const auto &node : nodes) {
|
|
|
@ -336,6 +361,7 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
SetKernelDefMap();
|
|
|
|
SetKernelDefMap();
|
|
|
|
SetReuseRefCount();
|
|
|
|
SetReuseRefCount();
|
|
|
|
|
|
|
|
SetSummaryNodesRefCount();
|
|
|
|
SetWorkSpaceList();
|
|
|
|
SetWorkSpaceList();
|
|
|
|
#ifdef MEM_REUSE_DEBUG
|
|
|
|
#ifdef MEM_REUSE_DEBUG
|
|
|
|
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
|
|
|
|
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
|
|
|
|