|
|
@ -454,8 +454,8 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|
|
|
std::vector<size_t> align_size_list;
|
|
|
|
std::vector<size_t> align_size_list;
|
|
|
|
for (uint64_t mem_size : output_sizes) {
|
|
|
|
for (uint64_t mem_size : output_sizes) {
|
|
|
|
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
|
|
|
|
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
|
|
|
|
MS_LOG(INFO) << "communication op addr exist";
|
|
|
|
MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
|
|
|
|
continue;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
|
|
|
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
|
|
|
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
|
|
|
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
|
|
@ -464,6 +464,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|
|
|
align_size_list.emplace_back(mem_size);
|
|
|
|
align_size_list.emplace_back(mem_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (align_size_list.empty()) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (type == kReuseDynamicMem) {
|
|
|
|
if (type == kReuseDynamicMem) {
|
|
|
|
// reuse communication op's all outputs' memory
|
|
|
|
// reuse communication op's all outputs' memory
|
|
|
|
type = kReuseDynamicCommMem;
|
|
|
|
type = kReuseDynamicCommMem;
|
|
|
@ -533,6 +537,10 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
|
|
|
|
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
|
|
|
|
auto input_node = input_node_with_index.first;
|
|
|
|
auto input_node = input_node_with_index.first;
|
|
|
|
|
|
|
|
if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
DeviceAddressPtr address = nullptr;
|
|
|
|
DeviceAddressPtr address = nullptr;
|
|
|
|
if (input_node->isa<CNode>()) {
|
|
|
|
if (input_node->isa<CNode>()) {
|
|
|
|
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
|
|
|
|
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
|
|
|
@ -811,6 +819,10 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
|
|
|
size_t index = 0;
|
|
|
|
size_t index = 0;
|
|
|
|
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
|
|
|
|
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
|
|
|
|
|
|
|
|
if (AnfAlgo::WorkspaceAddrExist(node, index)) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
|
|
|
|
auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
|
|
|
|
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
|
|
|
|
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
|
|
|
|
index++;
|
|
|
|
index++;
|
|
|
|