diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
index ad0e093d7f..839229be36 100644
--- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
@@ -137,6 +137,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
   if (is_enable_dynamic_mem) {
     // Use the dynamic memory pool.
     InitKernelRefCount(graph);
+    InitMemorySwapInfo(graph);
     InitKernelOutputAddress(graph);
   } else {
     AssignDynamicMemory(graph);
@@ -144,27 +145,24 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
 }
 
 bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
   bool ret = true;
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
   bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
-  auto iter = mem_swap_map_.find(graph);
-  if (iter == mem_swap_map_.end()) {
-    GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared<GPUMemCopyManager>();
-    iter = mem_swap_map_.emplace(graph, std::make_shared<MemSwapManager>(gpu_mem_copy_manager)).first;
-  }
-  mem_swap_manager_ = iter->second;
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
-  struct timeval start_time, end_time;
-  (void)gettimeofday(&start_time, nullptr);
   if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
+    auto graph_id = graph->graph_id();
+    auto iter = mem_swap_map_.find(graph_id);
+    if (iter == mem_swap_map_.end()) {
+      MS_LOG(EXCEPTION) << "Find memory swap map failed.";
+    }
+    mem_swap_manager_ = iter->second;
+    MS_EXCEPTION_IF_NULL(mem_swap_manager_);
     while (!LaunchKernelDynamic(graph)) {
-      ClearKernelOutputAddress(graph);
-      if (!mem_swap_manager_->mem_swap_init()) {
-        mem_swap_manager_->Init(graph);
-      }
-      if (!mem_swap_manager_->RetreatSwapInfo()) {
+      MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment.";
+      if (!UpdateMemorySwapInfo(graph)) {
         return false;
       }
     }
@@ -197,6 +195,16 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
   mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr;
 }
 
+void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared<GPUMemCopyManager>();
+  MS_EXCEPTION_IF_NULL(gpu_mem_copy_manager);
+  MemSwapManagerPtr mem_swap_manager = std::make_shared<MemSwapManager>(gpu_mem_copy_manager);
+  MS_EXCEPTION_IF_NULL(mem_swap_manager);
+  auto graph_id = graph->graph_id();
+  mem_swap_map_[graph_id] = mem_swap_manager;
+}
+
 void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto &kernels = graph->execution_order();
@@ -227,7 +235,6 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
       if (!AnfAlgo::OutputAddrExist(kernel, i)) {
         continue;
       }
-
       auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
       if (device_address->ptr_) {
         mem_manager_->FreeMemFromMemPool(device_address);
@@ -239,9 +246,12 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
 
 bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   auto graph_id = graph->graph_id();
-  auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
+  auto iter = mem_reuse_util_map_.find(graph_id);
+  if (iter == mem_reuse_util_map_.end()) {
+    MS_LOG(EXCEPTION) << "Find memory reuse map failed.";
+  }
+  auto mem_reuse_util_ptr = iter->second;
   MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
   // Reset the reference count.
   mem_reuse_util_ptr->ResetDynamicUsedRefCount();
@@ -263,27 +273,14 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
       MS_LOG(EXCEPTION) << "Launch kernel failed.";
     }
     FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id);
-
-    if (mem_swap_manager_->trigger_swap() && mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
-      CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
-      if (!AddMemSwapTask(kernel)) {
-        return false;
-      }
-    }
-
-    if (mem_swap_manager_->trigger_swap()) {
-      mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
-    }
+    UpdateMemorySwapTask(kernel);
   }
 
-  CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
-  if (mem_swap_manager_->trigger_swap()) {
-    mem_swap_manager_->ClearSwapQueue();
-  }
+  ClearSwapQueue();
   return true;
 }
 
-bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
+bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
   MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
   for (auto &mem_swap_info : mem_swap_info_list) {
@@ -311,14 +308,92 @@ bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
   return true;
 }
 
+bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  ClearKernelOutputAddress(graph);
+  if (!mem_swap_manager_->mem_swap_init()) {
+    mem_swap_manager_->Init(graph);
+  }
+  return mem_swap_manager_->RetreatSwapInfo();
+}
+
+bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  if (!mem_swap_manager_->trigger_swap()) {
+    return true;
+  }
+  if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
+    CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
+    if (!AddMemorySwapTask(kernel)) {
+      return false;
+    }
+  }
+  CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed.");
+  return true;
+}
+
+void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  if (!mem_swap_manager_->trigger_swap()) {
+    return;
+  }
+  while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
+    device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
+  }
+  auto status = device_address->status();
+  switch (status) {
+    case DeviceAddressStatus::kInDevice:
+      break;
+    case DeviceAddressStatus::kInDeviceToHost: {
+      mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
+      device_address->set_status(DeviceAddressStatus::kInDevice);
+      break;
+    }
+    case DeviceAddressStatus::kInHostToDevice: {
+      while (device_address->status() != DeviceAddressStatus::kInDevice) {
+        while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
+          device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
+        }
+      }
+      break;
+    }
+    case DeviceAddressStatus::kInHost:
+      MS_LOG(ERROR) << "Invalid device address status:" << status;
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Invalid device address status:" << status;
+  }
+}
+
+void GPUKernelRuntime::UpdateDeviceSwapQueue() {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  if (!mem_swap_manager_->trigger_swap()) {
+    return;
+  }
+  while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
+    if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
+      device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
+      mem_manager_->FreeMemFromMemPool(device_address_swap_out);
+    }
+  }
+}
+
+void GPUKernelRuntime::ClearSwapQueue() {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  if (!mem_swap_manager_->trigger_swap()) {
+    return;
+  }
+  mem_swap_manager_->ClearSwapQueue();
+}
+
 bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
   if (!ret) {
     if (!mem_swap_manager_->trigger_swap()) {
       return false;
     }
-    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
 
     while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
       if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
@@ -326,7 +401,6 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
         mem_manager_->FreeMemFromMemPool(device_address_swap_out);
       }
     }
-
     ret = mem_manager_->MallocMemFromMemPool(device_address, size);
     if (!ret) {
       return false;
@@ -337,12 +411,12 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
 
 void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
   if (!device_ptr) {
     if (!mem_swap_manager_->trigger_swap()) {
       return nullptr;
     }
-    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
 
     while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
       if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
@@ -350,7 +424,6 @@ void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
      }
    }
-
    device_ptr = mem_manager_->MallocMemFromMemPool(size);
    if (!device_ptr) {
      return nullptr;
@@ -377,40 +450,11 @@ bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
 bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(kernel_inputs);
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
     // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
     auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
-    if (mem_swap_manager_->trigger_swap()) {
-      while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
-        device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
-      }
-
-      auto status = device_address->status();
-      switch (status) {
-        case DeviceAddressStatus::kInDevice:
-          break;
-        case DeviceAddressStatus::kInHost:
-          break;
-        case DeviceAddressStatus::kInDeviceToHost: {
-          mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
-          device_address->set_status(DeviceAddressStatus::kInDevice);
-          break;
-        }
-        case DeviceAddressStatus::kInHostToDevice: {
-          while (device_address->status() != DeviceAddressStatus::kInDevice) {
-            while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
-              device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
-            }
-          }
-          break;
-        }
-        default:
-          MS_LOG(ERROR) << "Invaild device address status";
-          return false;
-      }
-    }
+    UpdateHostSwapQueue(device_address);
     MS_EXCEPTION_IF_NULL(device_address->ptr_);
     kernel::AddressPtr input = std::make_shared<kernel::Address>();
     MS_EXCEPTION_IF_NULL(input);
@@ -426,16 +470,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
                                                    AddressPtrList *kernel_outputs) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(kernel_outputs);
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
-  if (mem_swap_manager_->trigger_swap()) {
-    while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
-      if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
-        device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
-        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
-      }
-    }
-  }
+  UpdateDeviceSwapQueue();
   auto output_sizes = kernel_mod.GetOutputSizeList();
   for (size_t i = 0; i < output_sizes.size(); ++i) {
     auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h
index ea3ab17160..bc7e4ed22c 100644
--- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h
+++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h
@@ -53,9 +53,9 @@ class GPUKernelRuntime : public KernelRuntime {
   // The related functions and members for using dynamic memory pool.
   void InitKernelRefCount(const session::KernelGraph *graph);
   void InitKernelOutputAddress(const session::KernelGraph *graph);
+  void InitMemorySwapInfo(const session::KernelGraph *graph);
   void ClearKernelOutputAddress(const session::KernelGraph *graph);
   bool LaunchKernelDynamic(const session::KernelGraph *graph);
-  bool AddMemSwapTask(const AnfNodePtr &kernel);
   bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
   void *AttemptMallocMem(size_t size);
   bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
@@ -74,8 +74,14 @@ class GPUKernelRuntime : public KernelRuntime {
                                   std::vector<size_t> size_list);
   void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces,
                             uint32_t graph_id);
+  bool AddMemorySwapTask(const AnfNodePtr &kernel);
+  bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
+  bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
+  void UpdateHostSwapQueue(const DeviceAddressPtr device_address);
+  void UpdateDeviceSwapQueue();
+  void ClearSwapQueue();
   std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
-  std::unordered_map<session::KernelGraph *, MemSwapManagerPtr> mem_swap_map_;
+  std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
   MemSwapManagerPtr mem_swap_manager_{nullptr};
 };
 MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc
index 7765e93758..8d6d176970 100644
--- a/mindspore/ccsrc/session/gpu_session.cc
+++ b/mindspore/ccsrc/session/gpu_session.cc
@@ -187,8 +187,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
   GetSummaryNodes(graph.get());
   // Remove NoOp from execution graph
   opt::RemoveNopNode(graph.get());
-  // Alloc memory, including static memory and dynamic memory
-  AllocateMemory(graph.get());
+  // Set graph manager.
   MS_EXCEPTION_IF_NULL(context_);
   FuncGraphManagerPtr manager = MakeManager({graph});
   context_->AddManager(manager);
@@ -196,6 +195,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);
   }
+  // Alloc memory, including static memory and dynamic memory
+  AllocateMemory(graph.get());
   return graph_id;
 }
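Note for reviewers: the sketch below mirrors, in isolation, the retry loop this patch settles on in GPUKernelRuntime::Run and UpdateMemorySwapInfo: launch the graph with the dynamic memory pool and, on an out-of-memory failure, lazily initialize the swap manager, retreat the swap info, and retry. It is a minimal, self-contained mock under assumed names (FakeSwapManager, FakeRuntime, the simulated OOM counter); none of these are MindSpore APIs.

// Simplified stand-in for the retry-with-memory-swap loop introduced in GPUKernelRuntime::Run.
// All classes here are hypothetical mocks for illustration, not MindSpore code.
#include <iostream>

class FakeSwapManager {
 public:
  bool mem_swap_init() const { return initialized_; }
  void Init() { initialized_ = true; }
  // Pretend swap info can be retreated at most twice before giving up.
  bool RetreatSwapInfo() { return ++retreat_count_ <= 2; }

 private:
  bool initialized_{false};
  int retreat_count_{0};
};

class FakeRuntime {
 public:
  bool Run() {
    while (!LaunchKernelDynamic()) {
      std::cout << "Run out of memory and try memory swapping...\n";
      if (!UpdateMemorySwapInfo()) {
        return false;  // Swapping cannot free any more memory.
      }
    }
    return true;
  }

 private:
  // Simulate OOM on the first two launch attempts.
  bool LaunchKernelDynamic() { return attempts_++ >= 2; }

  // Mirrors the patch: lazily init the swap manager, then retreat swap info.
  bool UpdateMemorySwapInfo() {
    if (!swap_manager_.mem_swap_init()) {
      swap_manager_.Init();
    }
    return swap_manager_.RetreatSwapInfo();
  }

  FakeSwapManager swap_manager_;
  int attempts_{0};
};

int main() {
  FakeRuntime runtime;
  std::cout << (runtime.Run() ? "graph launched" : "launch failed") << "\n";
  return 0;
}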