@@ -110,23 +110,30 @@ ParallelExecutor::ParallelExecutor(
   // Bcast Parameters to all GPUs
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
-  std::unique_ptr<ncclUniqueId> nccl_id = nullptr;
+  ncclUniqueId *nccl_id = nullptr;
   bool need_group_call = true;
-  if (nccl_id_var != nullptr) {
-    nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
-  } else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
-    nccl_id.reset(new ncclUniqueId());
-    PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id.get()));
-    *member_->global_scope_->Var(NCCL_ID_VARNAME)
-        ->GetMutable<ncclUniqueId>() = *nccl_id.get();
+  if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
+    // parallel graph mode should initialize nccl by ncclCommInitRank since
+    // it calls nccl operators per device per thread.
+    if (nccl_id_var == nullptr) {
+      nccl_id = new ncclUniqueId();
+      PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
+      *member_->global_scope_->Var(NCCL_ID_VARNAME)
+          ->GetMutable<ncclUniqueId>() = *nccl_id;
+    } else {
+      nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+    }
     need_group_call = false;
+  } else if (nccl_id_var != nullptr) {  // the other executor types.
+    // distributed training with nccl mode initializes the nccl id in the
+    // startup_program.
+    nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
   } else {
-    // init nccl_id in NCCLContextMap
+    // initialize NCCL by ncclCommInitAll; no nccl_id is needed.
   }

   member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-      member_->places_, nccl_id.get(), num_trainers, trainer_id,
-      need_group_call));
+      member_->places_, nccl_id, num_trainers, trainer_id, need_group_call));
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
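For context (not part of the patch): the two NCCL initialization strategies the comments contrast look roughly as follows against the raw NCCL API. This is a minimal standalone sketch, not Paddle's NCCLContextMap; the helper names and the rank-equals-device mapping are assumptions for illustration, and error handling is elided.

// Sketch of the two NCCL init paths contrasted in the comments above.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// ncclCommInitAll: a single process owns every GPU; NCCL generates the
// unique id internally, so no nccl_id has to be shared beforehand.
std::vector<ncclComm_t> InitAllInOneProcess(int ndev) {
  std::vector<ncclComm_t> comms(ndev);
  // devlist == nullptr means "use devices 0..ndev-1".
  ncclCommInitAll(comms.data(), ndev, nullptr);
  return comms;
}

// ncclCommInitRank: one communicator per rank (per device per thread), which
// is what parallel graph mode needs. Every rank must join with the same
// ncclUniqueId -- hence the patch materializes it under NCCL_ID_VARNAME.
std::vector<ncclComm_t> InitPerRank(int nranks, ncclUniqueId id) {
  std::vector<ncclComm_t> comms(nranks);
  ncclGroupStart();  // required when one thread initializes several comms
  for (int rank = 0; rank < nranks; ++rank) {
    cudaSetDevice(rank);  // assumes rank == device index, for illustration
    ncclCommInitRank(&comms[rank], nranks, id, rank);
  }
  ncclGroupEnd();
  return comms;
}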