|
|
|
@ -40,21 +40,24 @@ void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (allreduce_kernels.size() > 1) {
|
|
|
|
|
DeviceStream comm_stream = nullptr;
|
|
|
|
|
GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
|
|
|
|
|
std::transform(
|
|
|
|
|
allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) {
|
|
|
|
|
AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), allreduce_kernel);
|
|
|
|
|
return allreduce_kernel;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// Assign multiple streams only when there's Recv node for AllReduce.
|
|
|
|
|
std::vector<SendRecvPair> send_recv_pairs;
|
|
|
|
|
FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs);
|
|
|
|
|
InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
|
|
|
|
|
if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) {
|
|
|
|
|
DeviceStream comm_stream = nullptr;
|
|
|
|
|
GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
|
|
|
|
|
std::transform(
|
|
|
|
|
allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) {
|
|
|
|
|
AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), allreduce_kernel);
|
|
|
|
|
return allreduce_kernel;
|
|
|
|
|
});
|
|
|
|
|
InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
|
|
|
|
|
} else {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
|
|
|
|
bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
|
|
|
|
std::vector<SendRecvPair> *send_recv_pairs) {
|
|
|
|
|
auto execution_kernels = kernel_graph->execution_order();
|
|
|
|
|
std::vector<CNodePtr>::iterator iter, iter_begin;
|
|
|
|
@ -77,14 +80,15 @@ void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
|
|
|
|
|
std::vector<CNodePtr>::iterator mock_recv_node_iter =
|
|
|
|
|
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
|
|
|
|
|
if (mock_recv_node_iter == iter_end) {
|
|
|
|
|
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
|
|
|
|
|
continue;
|
|
|
|
|
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
|
|
|
|
|
IntToSize(mock_recv_node_iter - iter_begin)};
|
|
|
|
|
send_recv_pairs->push_back(pair2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
|
|
|
|
|