|
|
|
@ -458,7 +458,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|
|
|
|
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(address);
|
|
|
|
|
if (output_ptr == nullptr) {
|
|
|
|
|
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address);
|
|
|
|
|
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_ptr);
|
|
|
|
|
} else {
|
|
|
|
|
address->set_ptr(output_ptr);
|
|
|
|
@ -515,8 +515,17 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
|
|
|
|
|
MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->inputs().size() < 2) {
|
|
|
|
|
// communication node's input should contain itself and at least one input
|
|
|
|
|
MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto first_input_node = cnode->input(1);
|
|
|
|
|
auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
|
|
|
|
|
uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
|
|
|
|
|
addr_size[0].first, true);
|
|
|
|
|
for (const auto &iter : addr_size) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter.first);
|
|
|
|
|
iter.first->set_ptr(input_ptr);
|
|
|
|
@ -568,7 +577,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
|
|
|
|
|
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
|
|
|
|
|
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
|
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address);
|
|
|
|
|
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ptr);
|
|
|
|
|
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
|
|
|
|
|
AnfAlgo::SetOutputAddr(device_address, i, node.get());
|
|
|
|
|