diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc index ebd23948..41f24b94 100755 --- a/ge/graph/build/memory/block_mem_assigner.cc +++ b/ge/graph/build/memory/block_mem_assigner.cc @@ -1121,7 +1121,6 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, } } reusable_block->continuous_block_ = continuous; - reusable_block->ref_count_++; reusable_blocks_[memory_type][stream_id].erase((++it).base()); return reusable_block; } @@ -1136,7 +1135,6 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); block->AddNodeTypeIndex({n, mem_type, out_index, false, continuous_life_begin_}, real_size, no_align_size); block->stream_id_ = node_op_desc->GetStreamId(); - block->ref_count_++; block->continuous_block_ = continuous; block->batch_label_ = batch_label; if (mem_type == kOutput) { @@ -1266,6 +1264,7 @@ Status BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vectorfirst_continuous_block_ = true; block->last_continuous_block_ = true; + ++(block->ref_count_); } else { GELOGE(INTERNAL_ERROR, "node apply continuous output memory failed. node_name:%s", n->GetName().c_str()); return INTERNAL_ERROR; @@ -1289,6 +1288,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, return nullptr, "Get no align size failed"); std::string symbol; + bool reuse_input = false; if (IsSymbolExist(node_index_io, symbol)) { block = symbol_blocks_[symbol]; GE_IF_BOOL_EXEC(block == nullptr, GELOGE(FAILED, "Node %s ref block is nullptr.", node_op_desc->GetName().c_str()); @@ -1303,6 +1303,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, block->SetLifeTimeEnd(life_time_); block->AddNodeTypeIndex({n, kOutput, index, true, continuous_life_begin_}, size, no_align_size); block->ref_count_++; + reuse_input = true; // add new size align_size = block_size; @@ -1336,7 +1337,6 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, workspace_reuse_flag, is_op_reuse_mem, continuous, memory_type); } GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "Block is nullptr."); - int out_count_reuse_input = block->ref_count_; int out_count = 0; GE_IF_BOOL_EXEC(index >= n->GetAllOutDataAnchors().size(), GELOGE(FAILED, "index is out of range."); return nullptr); auto out_data_anchor = n->GetOutDataAnchor(index); @@ -1351,28 +1351,8 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, out_count++; } } - bool reuse_input = false; - for (const auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - auto owner_node = in_anchor->GetOwnerNode(); - GE_IF_BOOL_EXEC(owner_node == nullptr, continue); - auto op_desc = owner_node->GetOpDesc(); - GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - for (uint32_t i = 0; i < static_cast(op_desc->GetOutputsSize()); i++) { - bool dst_reuse_input = false; - uint32_t dst_reuse_input_index = 0; - auto owner_node_op_desc = op_desc->GetOutputDescPtr(i); - GE_IF_BOOL_EXEC(owner_node_op_desc == nullptr, continue); - GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInput(*owner_node_op_desc, dst_reuse_input) != SUCCESS, - GELOGI("Get dst_reuse_input failed")); - GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInputIndex(*owner_node_op_desc, dst_reuse_input_index) != SUCCESS, - GELOGI("Get dst_reuse_input_index failed")); - if (dst_reuse_input && (dst_reuse_input_index == static_cast(in_anchor->GetIdx()))) { - out_count_reuse_input += 1; - reuse_input = true; - } - } - } - block->ref_count_ = reuse_input ? out_count_reuse_input + out_count - 1 : out_count; + block->ref_count_ = (reuse_input && out_count != 0) ? (block->ref_count_ + out_count - 1) + : (block->ref_count_ + out_count); return block; } @@ -1484,12 +1464,25 @@ void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_mapGetName().c_str()); - if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && - (node_type_indexs.back().index == static_cast(in_anchor->GetPeerOutAnchor()->GetIdx()))) { + bool is_block_matched = false; + for (auto &node_type_index : node_type_indexs) { + is_block_matched = (node_type_index.node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && + (node_type_index.index == static_cast(in_anchor->GetPeerOutAnchor()->GetIdx())); + if (is_block_matched) { + GELOGI("Block of peer out is matched. Peer node:%s, output index:%u, " + "current node:%s, input index:%d, block ref_count:%d.", + node_type_index.node->GetName().c_str(), node_type_index.index, + node->GetName().c_str(), in_anchor->GetIdx(), block->ref_count_); + break; + } + } + + if (is_block_matched) { ReleaseMemory(block, reusable_memory, (node->GetOpDesc()->GetStreamId() == block->stream_id_)); if (block->ref_count_ == 0 && block->same_stream_) { SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); } + break; } } } @@ -1530,6 +1523,21 @@ void CheckAndGetOpReuseEnv(const string &env, vector &env_vec, bool &op_ return; } +void BlockMemAssigner::CheckAndReleaseSuspendedBlock(const NodePtr &node, uint32_t idx, MemoryBlock *block) { + if (node == nullptr || node->GetOpDesc() == nullptr || block == nullptr) { + return; + } + int64_t stream_id = node->GetOpDesc()->GetStreamId(); + auto out_data_anchor = node->GetOutDataAnchor(static_cast(idx)); + bool is_suspended = (out_data_anchor != nullptr) && (out_data_anchor->GetPeerInDataNodesSize() == 0); + if (is_suspended) { + block->ref_count_ = (block->ref_count_ != 0) ? (block->ref_count_) : (1); + stream_workspace_blocks_[block->memory_type_][stream_id].emplace_back(block); + GELOGI("The output is suspended, and will be released in allocation of next node. Name:%s, index:%u, " + "size:%zu, ref_count:%d.", node->GetName().c_str(), idx, block->Size(), block->ref_count_); + } +} + Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector &ranges) { auto op_desc = node->GetOpDesc(); int64_t stream_id = op_desc->GetStreamId(); @@ -1560,7 +1568,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector // Allocate memory for the current node and release node memory of the same size in the workspace GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", for (auto iter = stream_workspace_blocks_.begin(); iter != stream_workspace_blocks_.end(); - ++iter) { ReleaseMemorys(iter->second[stream_id], reusable_blocks_[iter->first][stream_id]); }); + ++iter) { ReleaseMemorys(iter->second[stream_id], reusable_blocks_[iter->first][stream_id]); + iter->second[stream_id].clear();}); if (IsContinuousOutput(node)) { return ApplyContinuousMemory(node, ranges, is_op_reuse_mem_); } @@ -1621,6 +1630,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector continue; } symbol_blocks_[iter->second] = mem_block; + // The output is suspended, and will be released in allocation of next node. + CheckAndReleaseSuspendedBlock(node, i, mem_block); } } return SUCCESS; @@ -1648,9 +1659,6 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { if (AssignOutputMemoryWithReuse(n, ranges) != SUCCESS) { return; } - for (auto iter = stream_workspace_blocks_.begin(); iter != stream_workspace_blocks_.end(); ++iter) { - iter->second[stream_id].clear(); - } vector temp; int64_t tatal_size = 0; GetNodeWorkSpaceSize(n, temp, tatal_size); @@ -1692,6 +1700,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { kWorkspace, n, static_cast(i), workspace_reuse_flag, is_op_reuse_mem_, false, memory_type); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); + ++(mem_block->ref_count_); CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block, memory_type); } for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { diff --git a/ge/graph/build/memory/block_mem_assigner.h b/ge/graph/build/memory/block_mem_assigner.h index 4401108d..d0128dd5 100755 --- a/ge/graph/build/memory/block_mem_assigner.h +++ b/ge/graph/build/memory/block_mem_assigner.h @@ -454,6 +454,8 @@ class BlockMemAssigner : public MemAssigner { void MarkContinuousAllocedForOneInputFromVariable(const NodePtr &node); + void CheckAndReleaseSuspendedBlock(const NodePtr &node, uint32_t idx, MemoryBlock *block); + std::unordered_map>> reusable_blocks_; std::unordered_map>> stream_workspace_blocks_;