|
|
|
@ -411,7 +411,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|
|
|
|
}
|
|
|
|
|
auto tensor_size = CountNodeDeviceMemorySize(item, index);
|
|
|
|
|
auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
|
|
|
|
|
if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) {
|
|
|
|
|
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputAddr(address, index, item.get());
|
|
|
|
@ -517,7 +517,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|
|
|
|
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(address);
|
|
|
|
|
if (output_ptr == nullptr) {
|
|
|
|
|
output_ptr = mem_manager_->MallocMem(address, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0));
|
|
|
|
|
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_ptr);
|
|
|
|
|
} else {
|
|
|
|
|
address->set_ptr(output_ptr);
|
|
|
|
@ -565,8 +565,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
|
|
|
|
|
if (addr_size.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
uint8_t *input_ptr =
|
|
|
|
|
mem_manager_->MallocMem(addr_size[0].first, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0));
|
|
|
|
|
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first);
|
|
|
|
|
for (const auto &iter : addr_size) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter.first);
|
|
|
|
|
iter.first->set_ptr(input_ptr);
|
|
|
|
@ -600,8 +599,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
|
|
|
|
|
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
|
|
|
|
|
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
|
uint8_t *ptr =
|
|
|
|
|
mem_manager_->MallocMem(device_address, type, output_sizes[i], std::pair<AnfNodePtr, size_t>(node, i));
|
|
|
|
|
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ptr);
|
|
|
|
|
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
|
|
|
|
|
AnfAlgo::SetOutputAddr(device_address, i, node.get());
|
|
|
|
@ -639,7 +637,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
|
|
|
|
|
MS_EXCEPTION_IF_NULL(address);
|
|
|
|
|
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
|
|
|
|
|
} else if (mem_manager_->MallocMem(address, kStaticMem, node_size) == nullptr) {
|
|
|
|
|
} else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
|
|
|
|
@ -675,7 +673,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(address);
|
|
|
|
|
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
|
|
|
|
|
} else if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) {
|
|
|
|
|
} else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
|
|
|
|
@ -859,8 +857,8 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
|
|
|
|
|
auto device_address = CreateDeviceAddress(nullptr, size, format, type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
|
|
|
|
auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size);
|
|
|
|
|
device_address->set_ptr(base_ptr);
|
|
|
|
|
auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size, device_address);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(base_ptr);
|
|
|
|
|
return device_address;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|