From 301b7039a86c66fd09f73a71fa845213e35b509b Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Thu, 12 Nov 2020 10:03:19 +0800 Subject: [PATCH] fix communication input address bug --- .../ccsrc/runtime/device/kernel_runtime.cc | 17 +++++++++++++---- .../ccsrc/runtime/device/memory_manager.cc | 4 ++-- mindspore/ccsrc/runtime/device/memory_manager.h | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index e634fd6c45..eaec96454d 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -458,7 +458,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type); MS_EXCEPTION_IF_NULL(address); if (output_ptr == nullptr) { - output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address); + output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true); MS_EXCEPTION_IF_NULL(output_ptr); } else { address->set_ptr(output_ptr); @@ -515,8 +515,17 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input."; } } - - uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() < 2) { + // communication node's input should contain itself and at least on input + MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope(); + return; + } + auto first_input_node = cnode->input(1); + auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true); + uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size, + addr_size[0].first, true); for (const auto &iter : addr_size) { MS_EXCEPTION_IF_NULL(iter.first); iter.first->set_ptr(input_ptr); @@ -568,7 +577,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); MS_EXCEPTION_IF_NULL(device_address); - uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address); + uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false); MS_EXCEPTION_IF_NULL(ptr); device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); AnfAlgo::SetOutputAddr(device_address, i, node.get()); diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc index 102c1c5a2b..31ec68a827 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -83,13 +83,13 @@ void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph *graph) { } uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, - const DeviceAddressPtr &address) { + const DeviceAddressPtr &address, bool comm_mem) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(address); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); uint8_t *ptr = nullptr; - if (AnfAlgo::IsCommunicationOp(node)) { + if (comm_mem) { bool communication_mem = false; if (context_ptr->get_param(MS_CTX_ENABLE_HCCL)) { communication_mem = true; diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h index ffde7c8a90..6fce89c881 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -46,7 +46,7 @@ class MemoryManager { void MallocReusedDynamicMem(const session::KernelGraph *graph); void MallocSomasDynamicMem(const session::KernelGraph *graph); uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size, - const DeviceAddressPtr &address); + const DeviceAddressPtr &address, bool comm_mem); uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size); virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address);