!3882 Decoupling the interface of mallocing mem

Merge pull request !3882 from JoyLvliang/decoupling-the-interface-of-mallocing-mem
pull/3882/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0d2ef3ba38

@ -411,7 +411,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
auto tensor_size = CountNodeDeviceMemorySize(item, index);
auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) {
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
}
AnfAlgo::SetOutputAddr(address, index, item.get());
@ -517,7 +517,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
MS_EXCEPTION_IF_NULL(address);
if (output_ptr == nullptr) {
output_ptr = mem_manager_->MallocMem(address, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0));
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address);
MS_EXCEPTION_IF_NULL(output_ptr);
} else {
address->set_ptr(output_ptr);
@ -565,8 +565,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
if (addr_size.empty()) {
return;
}
uint8_t *input_ptr =
mem_manager_->MallocMem(addr_size[0].first, type, total_size, std::pair<AnfNodePtr, size_t>(node, 0));
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first);
for (const auto &iter : addr_size) {
MS_EXCEPTION_IF_NULL(iter.first);
iter.first->set_ptr(input_ptr);
@ -600,8 +599,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
MS_EXCEPTION_IF_NULL(device_address);
uint8_t *ptr =
mem_manager_->MallocMem(device_address, type, output_sizes[i], std::pair<AnfNodePtr, size_t>(node, i));
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address);
MS_EXCEPTION_IF_NULL(ptr);
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
AnfAlgo::SetOutputAddr(device_address, i, node.get());
@ -639,7 +637,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
MS_EXCEPTION_IF_NULL(address);
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) {
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
} else if (mem_manager_->MallocMem(address, kStaticMem, node_size) == nullptr) {
} else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
}
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
@ -675,7 +673,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(address);
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
} else if (mem_manager_->MallocMem(address, kStaticMem, tensor_size) == nullptr) {
} else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
}
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
@ -859,8 +857,8 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
auto device_address = CreateDeviceAddress(nullptr, size, format, type);
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(mem_manager_);
auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size);
device_address->set_ptr(base_ptr);
auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size, device_address);
MS_EXCEPTION_IF_NULL(base_ptr);
return device_address;
}

@ -45,8 +45,10 @@ void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) {
mem_reuse_util_ptr_->set_mem_base(base_ptr);
}
uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
const DeviceAddressPtr &address) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(address);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
uint8_t *ptr = nullptr;
@ -57,23 +59,30 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me
}
if (type == kStaticMem) {
ptr = MallocStaticMem(size, communication_mem);
address->from_mem_pool_ = true;
if (communication_mem) {
address->communication_ptr_ = ptr - kMemAlignSize;
}
} else if (type == kReuseDynamicCommMem) {
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
} else {
ptr = MallocDynamicMem(size, communication_mem);
}
address->ptr_ = ptr;
return ptr;
}
if (type == kStaticMem) {
ptr = MallocStaticMem(size, false);
address->from_mem_pool_ = true;
} else if (type == kDynamicMem) {
ptr = MallocDynamicMem(size, false);
} else if (type == kReuseDynamicMem) {
MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
}
address->ptr_ = ptr;
return ptr;
}
@ -85,38 +94,16 @@ uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index,
return MallocDynamicMem(size, false);
}
uint8_t *MemoryManager::MallocMem(MemType type, size_t size) {
uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
// Allocates `size` bytes of device memory of the given `type` and records the
// result in `address` (this is the decoupled interface introduced by this
// change: the DeviceAddress is now the last parameter of the overload).
// Returns the raw pointer, or nullptr when `type` is neither kStaticMem nor
// kDynamicMem — callers in this diff (e.g. AssignStaticMemoryInput) treat a
// nullptr return as an allocation failure and raise MS_LOG(EXCEPTION).
MS_EXCEPTION_IF_NULL(address);
uint8_t *ptr = nullptr;
if (type == kStaticMem) {
ptr = MallocStaticMem(size, false);
// NOTE(review): static allocations mark the address as pool-owned here;
// presumably this controls who frees the memory later — confirm against
// DeviceAddress's destructor/free path.
address->from_mem_pool_ = true;
} else if (type == kDynamicMem) {
ptr = MallocDynamicMem(size, false);
}
return ptr;
}
// Legacy overload REMOVED by this change. Its responsibilities were split:
// the node-output path moved into MallocOutputMem(node, index, type, size,
// address), and the plain path became MallocMem(type, size, address) above,
// each of which now sets address->ptr_ / from_mem_pool_ / communication_ptr_
// itself instead of this wrapper doing it afterwards.
uint8_t *MemoryManager::MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size,
const session::KernelWithIndex &node_with_index) {
MS_EXCEPTION_IF_NULL(address);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
uint8_t *ptr = nullptr;
// A non-null node selects the per-output allocation path; the default
// argument (nullptr, 0) selects the plain typed allocation below.
if (node_with_index.first != nullptr) {
ptr = MallocOutputMem(node_with_index.first, node_with_index.second, flag, size);
MS_EXCEPTION_IF_NULL(ptr);
if (AnfAlgo::IsCommunicationOp(node_with_index.first) && context_ptr->enable_hccl()) {
// Communication ops keep a pointer kMemAlignSize before the data pointer;
// presumably the alignment header used by HCCL — confirm with the
// MallocStaticMem/MallocDynamicMem communication_mem layout.
address->communication_ptr_ = ptr - kMemAlignSize;
}
} else {
ptr = MallocMem(flag, size);
MS_EXCEPTION_IF_NULL(ptr);
}
address->ptr_ = ptr;
if (flag == kStaticMem) {
// Mirrors the pool-ownership marking done on the static path of the
// two-argument MallocMem overload.
address->from_mem_pool_ = true;
}
return ptr;
}

@ -41,11 +41,10 @@ class MemoryManager {
}
void MallocReusedDynamicMem(const session::KernelGraph *graph);
uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
const DeviceAddressPtr &address);
uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
uint8_t *MallocMem(const DeviceAddressPtr &address, MemType flag, size_t size,
const session::KernelWithIndex &node_with_index = std::pair<AnfNodePtr, size_t>(nullptr, 0));
virtual uint8_t *MallocMem(MemType type, size_t size);
virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address);
virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *MallocMemFromMemPool(size_t size);

Loading…
Cancel
Save