diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
index 560893ee51..4c7b897cac 100644
--- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
@@ -21,8 +21,8 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-const uint64_t kAscendDeviceMemGB = 24;
-const uint64_t kAscendMemPoolGB = 6;
+const uint64_t kAscendDeviceMemGB = 26;
+const uint64_t kAscendMemPoolGB = 4;
 const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
 const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
 
diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc
index 8a3647d980..649d34dfd2 100644
--- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc
+++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc
@@ -401,6 +401,15 @@ bool BestFitMemReuse::IsReusableStream(uint32_t curr_stream_id, uint32_t target_
   return curr_parallel_set.find(target_stream_id) == curr_parallel_set.end();
 }
 
+bool BestFitMemReuse::IsRelease(const std::string &kernel_name) {
+  // unable_used_node holds the node types whose output tensor cannot be released,
+  // even if its refcount is equal to zero.
+  std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
+                                                      prim::kPrimFusedBatchNorm->name(),
+                                                      prim::kPrimFusedBatchNormGrad->name()};
+  return unable_used_node.find(kernel_name) == unable_used_node.end();
+}
+
 void BestFitMemReuse::CheckTensorIndex(int tensor_index) const {
   if (tensor_index < 0) {
     MS_LOG(EXCEPTION) << "warning, please check tensor info.";
@@ -437,6 +446,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
     // update node input tensor refcount, and membuf list status
     UpdateNodeInputAndMembuf(op_def_ptr.get());
     // check node output tensor which refcount is equal to zero
+    if (IsRelease(op_def_ptr->kernel_name())) {
+      ReleaseNodeUnusedOutput(op_def_ptr.get());
+    }
 #ifdef MEM_REUSE_DEBUG
     MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_);
     ++op_num;
diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h
index e41c20d620..5ef16b7dc3 100644
--- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h
+++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h
@@ -102,6 +102,8 @@ class BestFitMemReuse {
   size_t GetAllocatedSize();
   // If the target stream can be reused by current stream
   bool IsReusableStream(uint32_t curr_stream_id, uint32_t target_stream_id);
+  // return false, when the node output cannot be released
+  bool IsRelease(const std::string &kernel_name);
   // set tensor_def and op_def
   void set_tensor_ptr_list(const std::vector<KernelRefCountPtr> &tensor_ptr_list) {
     tensor_ptr_list_ = tensor_ptr_list;