|
|
|
@ -401,6 +401,15 @@ bool BestFitMemReuse::IsReusableStream(uint32_t curr_stream_id, uint32_t target_
|
|
|
|
|
return curr_parallel_set.find(target_stream_id) == curr_parallel_set.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool BestFitMemReuse::IsRelease(const std::string &kernel_name) {
|
|
|
|
|
// unable_used_node include the node type that output tensor cannot be released,
|
|
|
|
|
// even if its refcount is equal to zero.
|
|
|
|
|
std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
|
|
|
|
|
prim::kPrimFusedBatchNorm->name(),
|
|
|
|
|
prim::kPrimFusedBatchNormGrad->name()};
|
|
|
|
|
return unable_used_node.find(kernel_name) == unable_used_node.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BestFitMemReuse::CheckTensorIndex(int tensor_index) const {
|
|
|
|
|
if (tensor_index < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "warning, please check tensor info.";
|
|
|
|
@ -437,6 +446,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
|
|
|
|
|
// update node input tensor refcount, and membuf list status
|
|
|
|
|
UpdateNodeInputAndMembuf(op_def_ptr.get());
|
|
|
|
|
// check node output tensor which refcount is equal to zero
|
|
|
|
|
if (IsRelease(op_def_ptr->kernel_name())) {
|
|
|
|
|
ReleaseNodeUnusedOutput(op_def_ptr.get());
|
|
|
|
|
}
|
|
|
|
|
#ifdef MEM_REUSE_DEBUG
|
|
|
|
|
MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_);
|
|
|
|
|
++op_num;
|
|
|
|
|