|
|
|
@ -8,27 +8,27 @@ NCCLManager::NCCLManager() {}
|
|
|
|
|
|
|
|
|
|
// Destructor: walks every entry of comm_table and, for each GPU id listed in
// that communicator's gpus_, selects the device, queues a stream wait on the
// slot's event, destroys the event and the NCCL communicator, and finally
// releases the communicator object itself.
//
// NOTE(review): this text looks like a rendered diff. Where two near-identical
// statements follow each other (auto* comm / auto& comm, NCCL_CHECK /
// PADDLE_ENFORCE, delete comm / comm.reset(nullptr)), the first appears to be
// the removed revision and the second its replacement - confirm against the
// repository before treating either as the live code.
NCCLManager::~NCCLManager() {
|
|
|
|
|
for (auto& p : comm_table) {
|
|
|
|
|
auto* comm = p.second;
|
|
|
|
|
auto& comm = p.second;
|
|
|
|
|
auto& gpus_ = comm->gpus_;
|
|
|
|
|
for (int i = 0; i < gpus_.size(); ++i) {
|
|
|
|
|
for (size_t i = 0; i < gpus_.size(); ++i) {
|
|
|
|
|
int gid = gpus_[i];
|
|
|
|
|
// Make gid the current device before touching its stream/event/comm.
platform::SetDeviceId(gid);
|
|
|
|
|
|
|
|
|
|
|
// mapping gid to idx
// NOTE(review): the slot is gid modulo the number of GPUs, which is only
// collision-free when the ids reduce to distinct remainders. For a set such
// as {0, 2} both ids give idx 0, so slot 0 would be destroyed twice and
// slot 1 never - verify what id sets callers may pass.
|
|
|
|
|
int idx = gid % gpus_.size();
|
|
|
|
|
// wait finish
// NOTE(review): per the CUDA runtime docs cudaStreamWaitEvent only makes the
// stream wait on the event; it does not block the host thread. Confirm that
// destroying the event and communicator right after is the intended ordering.
|
|
|
|
|
NCCL_CHECK(
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
|
|
|
|
|
|
|
|
|
|
|
NCCL_CHECK(cudaEventDestroy(comm->events_[idx]));
|
|
|
|
|
PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
|
|
|
|
|
|
|
|
|
|
|
NCCL_CHECK(ncclCommDestroy(comm->comms_[idx]));
|
|
|
|
|
PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx]));
|
|
|
|
|
}
|
|
|
|
|
// Release the communicator object: raw delete in the first revision, reset of
// the owning pointer held in comm_table in the second.
delete comm;
|
|
|
|
|
comm.reset(nullptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the Communicator registered in comm_table for the given GPU id list,
// creating and registering one when the lookup yields a null entry.
//
// @param gpus  device ids; used to build the lookup key, passed to
//              ncclCommInitAll, and iterated to create one event per device.
// @return      non-owning pointer to the table's Communicator for this key.
//
// NOTE(review): this text looks like a rendered diff. Paired near-identical
// statements appear to be the removed revision followed by its replacement
// (the signature with and without const, NCCL_CHECK / PADDLE_ENFORCE, raw
// pointer / reset + get). The hunk header in the middle of the body means some
// lines of the function are not shown here - confirm against the repository.
Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
|
|
|
|
|
Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) {
|
|
|
|
|
// Key = decimal gpu ids appended back to back.
// NOTE(review): in the lines visible here no separator is added, so {1, 23}
// and {12, 3} would both yield "123" - check the elided lines below.
std::string key;
|
|
|
|
|
for (auto& id : gpus) {
|
|
|
|
|
key += std::to_string(id);
|
|
|
|
@ -37,21 +37,24 @@ Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
|
|
|
|
|
|
|
|
|
|
|
// NOTE(review): mu is a local variable of this function, so every call locks
// its own freshly constructed mutex; as written the guard cannot exclude
// other callers from comm_table.
std::mutex mu;
|
|
|
|
|
std::lock_guard<std::mutex> lk(mu);
|
|
|
|
|
// First revision: operator[] lookup, then create on a null entry.
auto* comm = comm_table[key];
|
|
|
|
|
if (comm == nullptr) {
|
|
|
|
|
comm = new Communicator(gpus.size());
|
|
|
|
|
NCCL_CHECK(ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
|
|
|
|
|
|
|
|
|
|
|
// Second revision: find() lookup.
// NOTE(review): the result of find() is dereferenced without comparing it to
// comm_table.end(); presumably comm_table is a standard map type, in which
// case a key that was never inserted makes it->second invalid - verify.
auto it = comm_table.find(key);
|
|
|
|
|
|
|
|
|
|
|
if (it->second == nullptr) {
|
|
|
|
|
// NOTE(review): the object is held by a raw pointer until the reset() call
// further down; if any check in between leaves the function early (whether
// PADDLE_ENFORCE can do so is not visible here) it would not be released.
auto* comm = new Communicator(gpus);
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
|
|
|
|
|
|
|
|
|
|
|
// Create one event per listed device, with that device selected first.
for (size_t i = 0; i < gpus.size(); ++i) {
|
|
|
|
|
platform::SetDeviceId(gpus[i]);
|
|
|
|
|
|
|
|
|
|
|
// block wait
// The flags request blocking synchronization and disable timing for the event.
|
|
|
|
|
NCCL_CHECK(cudaEventCreateWithFlags(
|
|
|
|
|
&events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
|
|
|
|
|
PADDLE_ENFORCE(cudaEventCreateWithFlags(
|
|
|
|
|
&comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
|
|
|
|
|
}
|
|
|
|
|
// Register the new communicator under key (plain assignment in the first
// revision, handing the pointer to the table's owning pointer in the second).
comm_table[key] = comm;
|
|
|
|
|
comm_table[key].reset(comm);
|
|
|
|
|
}
|
|
|
|
|
// The caller receives a non-owning pointer; the table entry stays in place.
return comm;
|
|
|
|
|
return comm_table[key].get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|