!8491 use input node as input parameter to allocate communication input address

From: @laiyongqiang
Reviewed-by: 
Signed-off-by:
pull/8491/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 90aaf8a58c

@ -458,7 +458,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
MS_EXCEPTION_IF_NULL(address);
if (output_ptr == nullptr) {
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address);
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
MS_EXCEPTION_IF_NULL(output_ptr);
} else {
address->set_ptr(output_ptr);
@ -515,8 +515,17 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
}
}
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < 2) {
// communication node's input should contain itself and at least one input
MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
return;
}
auto first_input_node = cnode->input(1);
auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
addr_size[0].first, true);
for (const auto &iter : addr_size) {
MS_EXCEPTION_IF_NULL(iter.first);
iter.first->set_ptr(input_ptr);
@ -568,7 +577,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
MS_EXCEPTION_IF_NULL(device_address);
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address);
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
MS_EXCEPTION_IF_NULL(ptr);
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
AnfAlgo::SetOutputAddr(device_address, i, node.get());

@ -83,13 +83,13 @@ void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph *graph) {
}
uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
const DeviceAddressPtr &address) {
const DeviceAddressPtr &address, bool comm_mem) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(address);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
uint8_t *ptr = nullptr;
if (AnfAlgo::IsCommunicationOp(node)) {
if (comm_mem) {
bool communication_mem = false;
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
communication_mem = true;

@ -46,7 +46,7 @@ class MemoryManager {
void MallocReusedDynamicMem(const session::KernelGraph *graph);
void MallocSomasDynamicMem(const session::KernelGraph *graph);
uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
const DeviceAddressPtr &address);
const DeviceAddressPtr &address, bool comm_mem);
uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address);

Loading…
Cancel
Save