update by comment

revert-15207-remove_op_handle_lock_and_fix_var
Yancey1989 7 years ago
parent 82726402be
commit 5cc83f79bf

@ -110,23 +110,30 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
std::unique_ptr<ncclUniqueId> nccl_id = nullptr; ncclUniqueId *nccl_id = nullptr;
bool need_group_call = true; bool need_group_call = true;
if (nccl_id_var != nullptr) { if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>()); // parallel graph mode should initialize nccl by ncclCommInitRank since
} else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) { // it call nccl operator per device per thread.
nccl_id.reset(new ncclUniqueId()); if (nccl_id_var == nullptr) {
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id.get())); nccl_id = new ncclUniqueId();
*member_->global_scope_->Var(NCCL_ID_VARNAME) PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
->GetMutable<ncclUniqueId>() = *nccl_id.get(); *member_->global_scope_->Var(NCCL_ID_VARNAME)
->GetMutable<ncclUniqueId>() = *nccl_id;
} else {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
need_group_call = false; need_group_call = false;
} else if (nccl_id_var != nullptr) { // the other executor type.
// the distributed training with nccl mode would initialize the nccl id in
// startup_program.
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
} else { } else {
// init nccl_id in NCCLContextMap // initlize NCCL by ncclCommInitAll, do not need nccl_id.
} }
member_->nccl_ctxs_.reset(new platform::NCCLContextMap( member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id.get(), num_trainers, trainer_id, member_->places_, nccl_id, num_trainers, trainer_id, need_group_call));
need_group_call));
#else #else
PADDLE_THROW("Not compiled with CUDA"); PADDLE_THROW("Not compiled with CUDA");
#endif #endif

Loading…
Cancel
Save