|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "device/kernel_runtime.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <numeric>
|
|
|
|
|
#include <functional>
|
|
|
|
@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|
|
|
|
mem_manager_->ResetDynamicMemory();
|
|
|
|
|
AssignStaticMemory(graph);
|
|
|
|
|
AssignDynamicMemory(graph);
|
|
|
|
|
|
|
|
|
|
UpdateRefNodeOutputMem(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
|
|
|
|
session::KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
// assign memory for input nodes
|
|
|
|
|
RunOpAssignInputMemory(input_tensors, graph);
|
|
|
|
|
AssignStaticMemoryValueNode(graph);
|
|
|
|
|
for (const auto &cnode : graph->execution_order()) {
|
|
|
|
|
// assign memory for output nodes
|
|
|
|
|
RunOpAssignOutputMemory(cnode);
|
|
|
|
|
// assign memory for workspace
|
|
|
|
|
RunOpAssignWorkSpaceMemory(cnode);
|
|
|
|
|
}
|
|
|
|
|
UpdateRefNodeOutputMem(graph);
|
|
|
|
@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|
|
|
|
// Assigns static memory to the kernels that produce the graph outputs.
//
// Two passes are made over the graph outputs: communication ops get their
// memory first (through AssignCommunicationNodeMem), and every remaining real
// CNode kernel is handled afterwards through AssignNodeOutputMem.
// Outputs that are not CNodes, or are not real kernels, are ignored.
void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  // Non-communication kernels are parked here until the first pass is done.
  std::vector<session::KernelWithIndex> deferred_kernels;

  // Pass 1: communication ops.
  for (const auto &output : graph_outputs) {
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, true);
    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
    const bool is_real_cnode_kernel =
      kernel_with_index.first->isa<CNode>() && AnfAlgo::IsRealKernel(kernel_with_index.first);
    if (!is_real_cnode_kernel) {
      continue;
    }
    if (!AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
      deferred_kernels.emplace_back(kernel_with_index);
      continue;
    }
    AssignCommunicationNodeMem(kStaticMem, kernel_with_index.first);
  }

  // Pass 2: everything that was deferred above.
  for (const auto &deferred : deferred_kernels) {
    AssignNodeOutputMem(kStaticMem, deferred.first, SizeToInt(deferred.second));
  }
}
|
|
|
|
@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Assigns both the input and the output memory of a communication node.
//
// flag: memory kind forwarded unchanged to AssignCommunicationNodeOutputMem.
// node: the communication node; passed as-is to both helpers.
//
// The input side is always processed before the output side.
void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
  // Step 1: inputs of the communication op.
  AssignCommunicationNodeInputMem(node);

  // Step 2: outputs of the communication op, using the requested memory kind.
  AssignCommunicationNodeOutputMem(flag, node);
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
|
|
|
@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
size_t total_size = 0;
|
|
|
|
|
size_t output_index = 0;
|
|
|
|
|
std::vector<size_t> align_size_list;
|
|
|
|
|
for (uint64_t mem_size : output_sizes) {
|
|
|
|
|
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
|
|
|
|
|
MS_LOG(INFO) << "communication op addr exist";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (context_ptr->enable_hccl()) {
|
|
|
|
|
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
|
|
|
|
}
|
|
|
|
@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
|
|
|
|
|
// Creates a device address for output `index` of `anf_node` and registers it
// as that node's output address.
//
// The address is created with a null device pointer, using the size reported
// by the node's kernel mod and the node's output format / device data type.
//
// anf_node: node whose output is being pre-assigned; must not be null.
// index:    output index on `anf_node`; must be smaller than the number of
//           entries in the kernel mod's output size list.
// Returns the newly created device address (also stored on the node).
// Reports an error through MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION) when the
// node or its kernel mod is null, or when `index` is out of range.
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  // Guard the dereference below: a node without a kernel mod would otherwise
  // crash instead of producing a diagnosable error.
  MS_EXCEPTION_IF_NULL(kernel_mod);
  // Bind by const reference so the size list is not copied when the getter
  // returns a reference (a by-value return is lifetime-extended).
  const auto &output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.size() <= index) {
    MS_LOG(EXCEPTION) << "Previous node output size < node index";
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  // The device pointer is passed as nullptr here; only size/format/type are set.
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
|
|
|
|
|
size_t total_size = 0;
|
|
|
|
|
std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
|
|
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
|
|
|
|
|
auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(address);
|
|
|
|
|
auto mem_size = address->size();
|
|
|
|
|
if (context_ptr->enable_hccl()) {
|
|
|
|
|
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
|
|
|
|
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
|
|
|
|
|
auto input_node = input_node_with_index.first;
|
|
|
|
|
DeviceAddressPtr address = nullptr;
|
|
|
|
|
if (input_node->isa<CNode>()) {
|
|
|
|
|
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(address);
|
|
|
|
|
auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
|
|
|
|
|
total_size += mem_size;
|
|
|
|
|
addr_size.emplace_back(address.get(), mem_size);
|
|
|
|
|
}
|
|
|
|
@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
|
|
|
|
|
void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
|
|
|
|
if (AnfAlgo::IsCommunicationOp(node)) {
|
|
|
|
|
UpdateCommunicationOpInputMem(node);
|
|
|
|
|
AssignCommunicationNodeOutputMem(flag, node);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
|
|
|
|
|
MS_LOG(INFO) << "GetNext disable mem_reuse";
|
|
|
|
|
flag = kDynamicMem;
|
|
|
|
@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
|
|
|
|
|
mem_manager_->MallocReusedDynamicMem(graph);
|
|
|
|
|
mem_flag = kReuseDynamicMem;
|
|
|
|
|
}
|
|
|
|
|
auto &kernels = graph->execution_order();
|
|
|
|
|
for (auto &kernel : kernels) {
|
|
|
|
|
AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
|
|
|
|
|
AssignWorkSpaceMem(mem_flag, kernel);
|
|
|
|
|
auto &execution_nodes = graph->execution_order();
|
|
|
|
|
std::vector<CNodePtr> compute_nodes;
|
|
|
|
|
// communication nodes first
|
|
|
|
|
for (auto &node : execution_nodes) {
|
|
|
|
|
if (AnfAlgo::IsCommunicationOp(node)) {
|
|
|
|
|
// skip if the memory is already allocated
|
|
|
|
|
AssignCommunicationNodeMem(mem_flag, node);
|
|
|
|
|
} else {
|
|
|
|
|
compute_nodes.emplace_back(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// then compute nodes
|
|
|
|
|
for (auto &node : compute_nodes) {
|
|
|
|
|
AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
|
|
|
|
|
AssignWorkSpaceMem(mem_flag, node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|