From 4526ce6845c782a573bb97bf53c31dc341ffd8b7 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Wed, 31 Mar 2021 17:24:39 +0800 Subject: [PATCH] fix reallocate memory bug for communication op --- .../backend/session/anf_runtime_algorithm.cc | 8 ++++++++ .../backend/session/anf_runtime_algorithm.h | 2 ++ mindspore/ccsrc/runtime/device/kernel_info.cc | 7 +++++++ mindspore/ccsrc/runtime/device/kernel_info.h | 1 + mindspore/ccsrc/runtime/device/kernel_runtime.cc | 16 ++++++++++++++-- 5 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 288f7c9de0..02108cb7f1 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -884,6 +884,14 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ return kernel_info->OutputAddrExist(output_idx); } +bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` + auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->WorkspaceAddrExist(output_idx); +} + const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node) { KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index bb0a7cb254..2ed8a160dc 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -153,6 +153,8 @@ class AnfRuntimeAlgorithm { static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool 
visit_nop_node = true); // check whether output addr is exist or not static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); + // check whether workspace addr exists or not + static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx); // get address from prev node,input_index is the input index of current node related to prev node static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, bool visit_nop_node = true); diff --git a/mindspore/ccsrc/runtime/device/kernel_info.cc b/mindspore/ccsrc/runtime/device/kernel_info.cc index a7a500ff95..f85579221e 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.cc +++ b/mindspore/ccsrc/runtime/device/kernel_info.cc @@ -81,6 +81,13 @@ DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const { return workspace_address_list_[index]; } +bool KernelInfo::WorkspaceAddrExist(size_t index) const { + if (index >= workspace_address_list_.size()) { + return false; + } + return workspace_address_list_[index] != nullptr; +} + bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { if (workspace_address_list_.empty()) { // parameter and valuenode diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h index 7f8d17e0aa..33a80a5575 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.h +++ b/mindspore/ccsrc/runtime/device/kernel_info.h @@ -55,6 +55,7 @@ class KernelInfo : public KernelInfoDevice { bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); DeviceAddress *GetWorkspaceAddr(size_t index) const; DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; + bool WorkspaceAddrExist(size_t index) const; bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); kernel::KernelMod *MutableKernelMod() const; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc 
b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index c4e2bf52d5..69bab8fa4f 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -454,8 +454,8 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode std::vector<size_t> align_size_list; for (uint64_t mem_size : output_sizes) { if (AnfAlgo::OutputAddrExist(node, output_index++)) { - MS_LOG(INFO) << "communication op addr exist"; - continue; + MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address"; + return; } if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { mem_size = mem_manager_->GetCommonAlignSize(mem_size); @@ -464,6 +464,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode align_size_list.emplace_back(mem_size); } + if (align_size_list.empty()) { + return; + } + if (type == kReuseDynamicMem) { // reuse communication op's all outputs' memory type = kReuseDynamicCommMem; @@ -533,6 +537,10 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP for (size_t i = 0; i < input_num; ++i) { auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); auto input_node = input_node_with_index.first; + if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) { + MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address"; + return; + } DeviceAddressPtr address = nullptr; if (input_node->isa<CNode>()) { address = PreAssignCNodeMemory(input_node, input_node_with_index.second); @@ -811,6 +819,10 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(kernel_mod); size_t index = 0; for (auto &size : kernel_mod->GetWorkspaceSizeList()) { + if (AnfAlgo::WorkspaceAddrExist(node, index)) { + MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address"; + return; + } auto ptr = 
mem_manager_->MallocWorkSpaceMem(node, index, type, size); AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); index++;