|
|
|
@ -67,7 +67,6 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
|
|
|
|
|
comm_vec_.push_back(comm);
|
|
|
|
|
|
|
|
|
|
auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id);
|
|
|
|
|
|
|
|
|
@ -89,7 +88,6 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
|
|
|
|
|
ncclComm_t comms[kDevices];
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
|
|
|
|
|
comms, dev_ids.size(), dev_ids.data()));
|
|
|
|
|
comm_vec_.insert(comm_vec_.end(), comms, comms + kDevices);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0);
|
|
|
|
|
for (size_t i = 0; i < dev_ids.size(); ++i) {
|
|
|
|
@ -135,10 +133,10 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void NCCLCommContext::ReleaseNCCLComms() {
|
|
|
|
|
for (auto comm : comm_vec_) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::ncclCommDestroy(comm),
|
|
|
|
|
platform::errors::External("Fail to destroy nccl comm"));
|
|
|
|
|
for (auto& p : comm_map_) {
|
|
|
|
|
for (auto& q : p.second) {
|
|
|
|
|
q.second.reset();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|