|
|
|
@ -73,7 +73,9 @@ struct NCCLContextMap {
|
|
|
|
|
std::unordered_map<int, NCCLContext> contexts_;
|
|
|
|
|
std::vector<int> order_;
|
|
|
|
|
|
|
|
|
|
explicit NCCLContextMap(const std::vector<platform::Place> &places) {
|
|
|
|
|
explicit NCCLContextMap(const std::vector<platform::Place> &places,
|
|
|
|
|
ncclUniqueId *nccl_id = nullptr,
|
|
|
|
|
size_t node_count = 0, size_t trainer_id = 0) {
|
|
|
|
|
PADDLE_ENFORCE(!places.empty());
|
|
|
|
|
order_.reserve(places.size());
|
|
|
|
|
for (auto &p : places) {
|
|
|
|
@ -85,18 +87,36 @@ struct NCCLContextMap {
|
|
|
|
|
order_.size(), contexts_.size(),
|
|
|
|
|
"NCCL Context Map does not support contain two or more same device");
|
|
|
|
|
|
|
|
|
|
if (places.size() > 1) {
|
|
|
|
|
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
|
|
|
|
|
if (places.size() <= 1) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
|
|
|
|
|
// If nccl_id is passed in, we can assume we are doing multi-node training.
|
|
|
|
|
if (nccl_id == nullptr) {
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
|
|
|
|
|
comms.get(), static_cast<int>(order_.size()), order_.data()));
|
|
|
|
|
}
|
|
|
|
|
int i = 0;
|
|
|
|
|
for (auto &dev_id : order_) {
|
|
|
|
|
contexts_.at(dev_id).comm_ = comms[i++];
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_GT(node_count, 0);
|
|
|
|
|
PADDLE_ENFORCE_EQ(node_count % places.size(), 0,
|
|
|
|
|
"must have same number of GPUs on each node");
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
|
|
|
|
|
int nranks = node_count * order_.size();
|
|
|
|
|
for (auto &gpu_id : order_) {
|
|
|
|
|
int rank = trainer_id * order_.size() + gpu_id;
|
|
|
|
|
PADDLE_ENFORCE(cudaSetDevice(gpu_id));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ncclCommInitRank(comms.get() + gpu_id, nranks, *nccl_id, rank));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int i = 0;
|
|
|
|
|
for (auto &dev_id : order_) {
|
|
|
|
|
contexts_.at(dev_id).comm_ = comms[i++];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copy construction is explicitly disabled: an NCCLContextMap cannot be
// created as a copy of another one.
// NOTE(review): presumably because contexts_ holds NCCL communicator handles
// (comm_) that must not be duplicated — confirm against NCCLContext.
NCCLContextMap(const NCCLContextMap &other) = delete;
|
|
|
|
|