From 99f12f9105da9fe8d787d90b8d5f5f0e88627c67 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 14 Apr 2020 16:34:28 +0800 Subject: [PATCH] gpu uses dynamic memory pool by default --- .../ccsrc/device/gpu/gpu_kernel_runtime.cc | 25 +++++------- mindspore/ccsrc/device/memory_manager.cc | 7 ++++ mindspore/ccsrc/device/memory_manager.h | 1 + .../ccsrc/pre_activate/mem_reuse/mem_reuse.cc | 39 +++++++------------ mindspore/ccsrc/utils/context/ms_context.cc | 2 +- 5 files changed, 34 insertions(+), 40 deletions(-) diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 671a11f776..584f66eee7 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -127,9 +127,10 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); + bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); struct timeval start_time, end_time; (void)gettimeofday(&start_time, nullptr); - if (is_enable_dynamic_mem) { + if (is_enable_dynamic_mem && !is_enable_pynative_infer) { ret = LaunchKernelDynamic(graph); } else { ret = LaunchKernel(graph); @@ -152,7 +153,7 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { } mem_reuse_util_ptr->SetKernelDefMap(); mem_reuse_util_ptr->SetReuseRefCount(); - // Can't free the device address of graph output, so set the reference count of graph output specially, + // Can't free the device address of graph output, so set the reference count of graph output specially. mem_reuse_util_ptr->SetGraphOutputRefCount(); mem_reuse_util_ptr_ = mem_reuse_util_ptr; } @@ -351,6 +352,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, if (kernel_ref_count_ptr == nullptr) { continue; } + // Can't free the output of graph. + if (kernel_ref_count_ptr->ref_count_dynamic_use_ == memreuse::kMaxRefCount) { + continue; + } kernel_ref_count_ptr->ref_count_dynamic_use_--; if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { // Reset the reference count. @@ -360,14 +365,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); if (!is_communication_op) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(device_address->ptr_); - mem_manager_->FreeMemFromMemPool(device_address->ptr_); - device_address->ptr_ = nullptr; + mem_manager_->FreeMemFromMemPool(device_address); } } } - // Free the workspace of kernel. for (size_t i = 0; i < kernel_workspaces.size(); ++i) { auto workspace = kernel_workspaces[i]; @@ -388,10 +389,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr communication_op_input_ref_count_--; if (communication_op_input_ref_count_ == 0) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(device_address->ptr_); - mem_manager_->FreeMemFromMemPool(device_address->ptr_); - device_address->ptr_ = nullptr; + mem_manager_->FreeMemFromMemPool(device_address); } *is_communication_op = true; return; @@ -410,10 +408,7 @@ void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr communication_op_output_ref_count_--; if (communication_op_output_ref_count_ == 0) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(device_address->ptr_); - mem_manager_->FreeMemFromMemPool(device_address->ptr_); - device_address->ptr_ = nullptr; + mem_manager_->FreeMemFromMemPool(device_address); } *is_communication_op = true; } diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc index 6977628eb1..2fad5fc10e 100644 --- a/mindspore/ccsrc/device/memory_manager.cc +++ b/mindspore/ccsrc/device/memory_manager.cc @@ -155,6 +155,13 @@ void *MemoryManager::MallocMemFromMemPool(size_t size) { return nullptr; } +void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) { + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(address->ptr_); + FreeMemFromMemPool(address->ptr_); + address->ptr_ = nullptr; +} + void MemoryManager::FreeMemFromMemPool(void *device_ptr) { if (device_ptr == nullptr) { MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; diff --git a/mindspore/ccsrc/device/memory_manager.h b/mindspore/ccsrc/device/memory_manager.h index 82c22f4548..c90ffc380e 100644 --- a/mindspore/ccsrc/device/memory_manager.h +++ b/mindspore/ccsrc/device/memory_manager.h @@ -47,6 +47,7 @@ class MemoryManager { virtual void MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); virtual void *MallocMemFromMemPool(size_t size); + virtual void FreeMemFromMemPool(const DeviceAddressPtr address); virtual void FreeMemFromMemPool(void *device_ptr); size_t GetCommonAlignSize(size_t input_size) const; diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 2113fec653..d25b60003f 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -273,30 +273,21 @@ void MemReuseUtil::SetReuseRefCount() { } void MemReuseUtil::SetGraphOutputRefCount() { - for (const auto &output : graph_->outputs()) { - MS_EXCEPTION_IF_NULL(output); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { - if (!(output->isa())) { - continue; - } - auto cnode = output->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_node = cnode->input(i + 1); - MS_EXCEPTION_IF_NULL(input_node); - auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); - MS_EXCEPTION_IF_NULL(kernel_input.first); - if (!(kernel_input.first->isa())) { - continue; - } - auto ak_node = kernel_input.first->cast(); - auto key = ak_node.get(); - auto iter = kernel_output_refs_.find(key); - if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { - auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; - MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); - kernel_ref_count_ptr->ref_count_ = kMaxRefCount; - kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; - } + auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); + for (const auto &node : nodes) { + auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0); + MS_EXCEPTION_IF_NULL(kernel_input.first); + if (!kernel_input.first->isa() || !AnfAlgo::IsRealKernel(kernel_input.first)) { + continue; + } + auto ak_node = kernel_input.first->cast(); + auto key = ak_node.get(); + auto iter = kernel_output_refs_.find(key); + if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { + auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; + MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); + kernel_ref_count_ptr->ref_count_ = kMaxRefCount; + kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; } } #ifdef MEM_REUSE_DEBUG diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 6c15e16714..b1ab0205f2 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -75,7 +75,7 @@ MsContext::MsContext(const std::string& policy, const std::string& target) { precompile_only_ = false; auto_mixed_precision_flag_ = true; enable_pynative_infer_ = false; - enable_dynamic_mem_pool_ = false; + enable_dynamic_mem_pool_ = true; graph_memory_max_size_ = "0"; variable_memory_max_size_ = "0"; MS_LOG(INFO) << "Create context with backend policy:" << policy << ", device target:" << target << ".";