From 2891f0d20df3363fad0911aa3936b708440ea078 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 21 Apr 2020 10:19:16 +0800 Subject: [PATCH] gpu dynamic memory pool supports multi-allReduce --- .../ccsrc/device/gpu/gpu_kernel_runtime.cc | 122 ++++-------------- .../ccsrc/device/gpu/gpu_kernel_runtime.h | 3 - .../ccsrc/device/gpu/gpu_memory_manager.cc | 4 + .../ccsrc/device/gpu/gpu_memory_manager.h | 2 + mindspore/ccsrc/device/kernel_runtime.h | 2 +- mindspore/ccsrc/device/memory_manager.cc | 23 ++++ mindspore/ccsrc/device/memory_manager.h | 4 + .../kernel/gpu/arrays/transpose_gpu_kernel.h | 2 +- .../kernel/gpu/cuda_impl/unary_op_impl.cu | 16 +++ .../kernel/gpu/cuda_impl/unary_op_impl.cuh | 3 + .../kernel/gpu/math/unary_op_gpu_kernel.h | 1 + .../mem_reuse/mem_dynamic_allocator.cc | 31 +++++ .../mem_reuse/mem_dynamic_allocator.h | 2 + .../ccsrc/pre_activate/mem_reuse/mem_reuse.cc | 8 -- tests/st/nccl/test_nccl_all_reduce_op.py | 2 +- 15 files changed, 117 insertions(+), 108 deletions(-) diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 11b8bdc162..5dd4facb25 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -111,7 +111,8 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(mem_manager_); mem_manager_->ResetDynamicMemory(); - AssignStaticMemory(graph); + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); if (is_enable_dynamic_mem) { // Use the dynamic memory pool. @@ -181,7 +182,7 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto graph_id = graph->graph_id(); - // The inputs and outputs memory of communication kernel are special, so separate processing. + // The inputs and outputs memory of communication kernel need be continuous, so separate processing. AllocCommunicationOpDynamicRes(graph); auto &kernels = graph->execution_order(); @@ -229,15 +230,12 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod for (size_t i = 0; i < output_sizes.size(); ++i) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); - auto device_ptr = device_address->ptr_; - if (device_ptr == nullptr) { - device_ptr = mem_manager_->MallocMemFromMemPool(output_sizes[i]); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->ptr_ = device_ptr; + if (device_address->ptr_ == nullptr) { + mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); } kernel::AddressPtr output = std::make_shared(); MS_EXCEPTION_IF_NULL(output); - output->addr = device_ptr; + output->addr = device_address->ptr_; output->size = output_sizes[i]; kernel_outputs->push_back(output); } @@ -267,7 +265,6 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph if (kernel_name == kAllReduceOpName) { AllocCommunicationOpInputDynamicRes(kernel); AllocCommunicationOpOutputDynamicRes(kernel); - return; } } } @@ -275,48 +272,30 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - // The reference count of communication kernel input is not 0. - if (communication_op_input_ref_count_ != 0) { - MS_LOG(ERROR) << "The reference count of communication kernel input is not 0."; - return; - } - - size_t total = 0; - std::vector> addr_size; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); // The inputs of communication kernel are not released. - if ((i == 0) && (device_address->ptr_ != nullptr)) { - MS_LOG(ERROR) << "The inputs of communication kernel are not released."; - return; + if (device_address->ptr_ != nullptr) { + MS_LOG(INFO) << "The inputs of communication kernel are not released."; + mem_manager_->FreeMemFromMemPool(device_address); } - auto output_size = device_address->size_; - total += output_size; - addr_size.emplace_back(device_address.get(), output_size); - } - - auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); - MS_EXCEPTION_IF_NULL(device_mem_ptr); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(device_mem_ptr); - communication_op_input_ref_count_++; - device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); + total_size += device_address->size_; + size_list.emplace_back(device_address->size_); + addr_list.emplace_back(device_address); } + mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - // The reference count of communication kernel output is not 0. - if (communication_op_output_ref_count_ != 0) { - MS_LOG(ERROR) << "The reference count of communication kernel output is not 0."; - return; - } - - size_t total = 0; - std::vector> addr_size; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); @@ -324,22 +303,15 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); // The outputs of communication kernel are not released. - if ((i == 0) && (device_address->ptr_ != nullptr)) { - MS_LOG(ERROR) << "The outputs of communication kernel are not released."; - return; + if (device_address->ptr_ != nullptr) { + MS_LOG(INFO) << "The outputs of communication kernel are not released."; + mem_manager_->FreeMemFromMemPool(device_address); } - total += output_sizes[i]; - addr_size.emplace_back(device_address.get(), output_sizes[i]); - } - - auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); - MS_EXCEPTION_IF_NULL(device_mem_ptr); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(device_mem_ptr); - communication_op_output_ref_count_++; - device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); + total_size += output_sizes[i]; + size_list.emplace_back(output_sizes[i]); + addr_list.emplace_back(device_address); } + mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, @@ -362,14 +334,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } kernel_ref_count_ptr->ref_count_dynamic_use_--; if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + mem_manager_->FreeMemFromMemPool(device_address); // Reset the reference count. kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; - bool is_communication_op = false; - FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); - if (!is_communication_op) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - mem_manager_->FreeMemFromMemPool(device_address); - } } } // Free the output of kernel, if output has no reference. @@ -393,40 +361,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } } } - -void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, - bool *is_communication_op) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - // The inputs memory of communication kernel is one piece memory, need release together. - if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) { - communication_op_input_ref_count_--; - if (communication_op_input_ref_count_ == 0) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); - mem_manager_->FreeMemFromMemPool(device_address); - } - *is_communication_op = true; - return; - } - - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << cnode->inputs().size() - 1 - << "."; - } - auto input_node = cnode->input(input_idx + 1); - auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); - // The outputs memory of communication kernel is one piece memory, need release together. - if (AnfAlgo::GetCNodeName(kernel_input.first) == kAllReduceOpName) { - communication_op_output_ref_count_--; - if (communication_op_output_ref_count_ == 0) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); - mem_manager_->FreeMemFromMemPool(device_address); - } - *is_communication_op = true; - } -} } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h index e0eb2dc3f1..33d4b4be70 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h @@ -60,9 +60,6 @@ class GPUKernelRuntime : public KernelRuntime { void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, uint32_t graph_id); - void FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, bool *is_communication_op); - size_t communication_op_input_ref_count_{0}; - size_t communication_op_output_ref_count_{0}; std::unordered_map mem_reuse_util_map_; }; MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc index 8bb65963d8..6e81130b9c 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc +++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc @@ -29,6 +29,10 @@ void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); } +std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); +} + void GPUMemoryManager::MallocDeviceMemory() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h index cc5dac2a5e..c79fb9cc22 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h +++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#include #include "device/memory_manager.h" namespace mindspore { namespace device { @@ -30,6 +31,7 @@ class GPUMemoryManager : public MemoryManager { void *MallocMemFromMemPool(size_t size) override; void FreeMemFromMemPool(void *device_ptr) override; + std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); protected: uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index 8f4f769f55..b15cb31e17 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -67,6 +67,7 @@ class KernelRuntime { TypeId type_id) = 0; virtual bool SyncStream() = 0; void AssignStaticMemory(session::KernelGraph *graph); + void AssignStaticMemoryValueNode(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph); void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); @@ -81,7 +82,6 @@ class KernelRuntime { private: void AssignStaticMemoryOutput(const session::KernelGraph *graph); - void AssignStaticMemoryValueNode(session::KernelGraph *graph); void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc index 2fad5fc10e..dce54495b0 100644 --- a/mindspore/ccsrc/device/memory_manager.cc +++ b/mindspore/ccsrc/device/memory_manager.cc @@ -167,5 +167,28 @@ void MemoryManager::FreeMemFromMemPool(void *device_ptr) { MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; } } + +void MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); + if (addr_list.size() != device_ptr_list.size()) { + MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; + } + for (size_t i = 0; i < addr_list.size(); i++) { + MS_EXCEPTION_IF_NULL(device_ptr_list[i]); + MS_EXCEPTION_IF_NULL(addr_list[i]); + addr_list[i]->ptr_ = device_ptr_list[i]; + addr_list[i]->from_mem_pool_ = true; + } +} + +std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + if (total_size == 0) { + MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; + } + std::vector device_ptr_list; + device_ptr_list.emplace_back(nullptr); + return device_ptr_list; +} } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/memory_manager.h b/mindspore/ccsrc/device/memory_manager.h index c90ffc380e..dae0861506 100644 --- a/mindspore/ccsrc/device/memory_manager.h +++ b/mindspore/ccsrc/device/memory_manager.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ #include +#include #include "pre_activate/mem_reuse/mem_reuse.h" #include "pre_activate/mem_reuse/mem_reuse_allocator.h" namespace mindspore { @@ -49,6 +50,9 @@ class MemoryManager { virtual void *MallocMemFromMemPool(size_t size); virtual void FreeMemFromMemPool(const DeviceAddressPtr address); virtual void FreeMemFromMemPool(void *device_ptr); + virtual void MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); size_t GetCommonAlignSize(size_t input_size) const; size_t GetCommunicationAlignSize(size_t input_size) const; diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h index 198e8687fc..1c9cf925ea 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h @@ -44,7 +44,7 @@ class TransposeGpuFwdKernel : public GpuKernel { "cudaMemcpyAsync input_shape failed"); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcphalfyAsync input_axis failed"); + "cudaMemcpyAsync input_axis failed"); int size = SizeToInt(input_size_ / sizeof(T)); CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, reinterpret_cast(stream_ptr)); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu index 6022485251..5e7a25b8e6 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu @@ -60,6 +60,14 @@ __global__ void SquareKernel(T *input, T *output, size_t count) { return; } template +__global__ void ZeroslikeKernel(T *output, size_t count) { + T zero = 0.0; + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = zero; + } + return; +} +template void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ExponentialKernel<<>>(input, output, count); return; @@ -84,13 +92,21 @@ void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) { SquareKernel<<>>(input, output, count); return; } +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { + ZeroslikeKernel<<>>(output, count); + return; +} + template void Exponential(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Negative(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Square(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(float *output, size_t count, cudaStream_t cuda_stream); template void Exponential(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Negative(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Square(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh index f303c73d29..8ba9cb4a52 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh @@ -28,4 +28,7 @@ template void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); + #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h index 5b2414f8f1..d8fea7370b 100644 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h @@ -81,6 +81,7 @@ class UnaryOpGpuKernel : public GpuKernel { break; } case UNARY_OP_ZEROSLIKE: { + Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); return true; } default: { diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc index c9ef381f16..b7280f52ae 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc @@ -36,6 +36,37 @@ DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { return device_addr; } +std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, + std::vector size_list) { + // Pre-alloc the one whole piece memory. + auto device_addr = AllocTensorMem(total_size); + MS_EXCEPTION_IF_NULL(device_addr); + // Remove the pre-alloc memory. + auto mem_block = FindMemBlock(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + auto rest_size = mem_buf->size_ - total_size; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + // Split the pre-alloc memory into continuous memory by the size list. + DynamicMemBufPtr continuous_mem_buf; + std::vector device_addr_list; + auto buf_addr = device_addr; + for (size_t i = 0; i < size_list.size(); i++) { + continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); + (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); + device_addr_list.emplace_back(buf_addr); + buf_addr = AddressOffset(buf_addr, size_list[i]); + } + // Update the size of the last memory buf. + continuous_mem_buf->size_ += rest_size; + return device_addr_list; +} + size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { if (size == 0) { return DYNAMIC_MEM_ALIGN_SIZE; diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h index c628756070..07efa267aa 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h @@ -79,6 +79,8 @@ class DynamicMemPoolBestFit { virtual ~DynamicMemPoolBestFit(); // The main program entry of memory alloc. DeviceMemPtr AllocTensorMem(size_t size); + // The main program entry of continuous memory alloc. + std::vector AllocContinuousTensorMem(size_t total_size, std::vector size_list); // The main program entry of memory free. void FreeTensorMem(const DeviceMemPtr device_addr); // Release the real device memory. diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index d25b60003f..952dfe97e4 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -162,10 +162,6 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr if (iter == kernel_def_ptr->inputs_.end()) { kernel_def_ptr->inputs_[key].push_back(ref_ptr); } else { - if (std::any_of(iter->second.begin(), iter->second.end(), - [ref_ptr](const KernelRefCountPtr &it) { return (it.get() == ref_ptr.get()); })) { - break; - } iter->second.push_back(ref_ptr); } } @@ -185,10 +181,6 @@ void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_pt if (iter == kernel_def_ptr->outputs_.end()) { kernel_def_ptr->outputs_[key].push_back(kernel_ref); } else { - if (std::any_of(iter->second.begin(), iter->second.end(), - [kernel_ref](const KernelRefCountPtr &it) { return (it == kernel_ref); })) { - break; - } iter->second.push_back(kernel_ref); } } diff --git a/tests/st/nccl/test_nccl_all_reduce_op.py b/tests/st/nccl/test_nccl_all_reduce_op.py index 7c2e579463..3ba8b219e4 100644 --- a/tests/st/nccl/test_nccl_all_reduce_op.py +++ b/tests/st/nccl/test_nccl_all_reduce_op.py @@ -20,7 +20,7 @@ import mindspore.context as context from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size -context.set_context(mode=context.GRAPH_MODE, device_target='GPU', enable_dynamic_memory=False) +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') init('nccl') rank = get_rank()