@@ -231,7 +231,6 @@ ParallelExecutor::ParallelExecutor(
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
     ncclUniqueId *nccl_id = nullptr;
-    bool need_group_call = true;
     if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
       // parallel graph mode should initialize nccl by ncclCommInitRank since
       // it call nccl operator per device per thread.
@@ -243,17 +242,16 @@ ParallelExecutor::ParallelExecutor(
       } else {
         nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
       }
-      need_group_call = false;
     } else if (nccl_id_var != nullptr) {  // the other executor type.
       // the distributed training with nccl mode would initialize the nccl id in
       // startup_program.
       nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
     } else {
-      // initlize NCCL by ncclCommInitAll, do not need nccl_id.
+      // initialize NCCL by ncclCommInitAll, do not need to initialize the nccl_id.
     }
 
     member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-        member_->places_, nccl_id, num_trainers, trainer_id, need_group_call));
+        member_->places_, nccl_id, num_trainers, trainer_id));
 #else
     PADDLE_THROW("Not compiled with CUDA");
 #endif
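For context on the two bootstrap paths the comments in this hunk refer to, below is a minimal standalone sketch (not PaddlePaddle's NCCLContextMap, just the public NCCL API, with error handling omitted): ncclCommInitRank with a shared ncclUniqueId when every device or thread drives its own communicator, versus ncclCommInitAll when a single process owns all devices and no id is needed.

#include <cuda_runtime.h>
#include <nccl.h>

#include <vector>

// Path A: one rank per device, all ranks joined through a shared ncclUniqueId.
// This is the shape the kParallelGraph branch needs, since each device/thread
// runs its own nccl operator and initializes its communicator by rank.
std::vector<ncclComm_t> InitPerRank(int ndev) {
  ncclUniqueId id;
  ncclGetUniqueId(&id);  // in distributed training this id would be broadcast
  std::vector<ncclComm_t> comms(ndev);
  ncclGroupStart();  // group the per-rank inits issued from a single thread
  for (int rank = 0; rank < ndev; ++rank) {
    cudaSetDevice(rank);
    ncclCommInitRank(&comms[rank], ndev, id, rank);
  }
  ncclGroupEnd();
  return comms;
}

// Path B: single-process initialization over all visible devices; no
// ncclUniqueId is required, matching the ncclCommInitAll branch above.
std::vector<ncclComm_t> InitAll(int ndev) {
  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, /*devlist=*/nullptr);
  return comms;
}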
@@ -288,6 +286,14 @@ ParallelExecutor::ParallelExecutor(
   graphs.push_back(std::move(graph));
 #endif
 
+  auto max_memory_size = GetEagerDeletionThreshold();
+  // TODO(Yancey1989): fix gc failed on ParallelGraph executor.
+  if (max_memory_size >= 0 &&
+      exec_strategy.type_ != ExecutionStrategy::kParallelGraph) {
+    graphs[0] = member_->PrepareGCAndRefCnts(
+        std::move(graphs[0]), static_cast<size_t>(max_memory_size));
+  }
+
   // Step 3. Create vars in each scope. Passes may also create new vars.
   // skip control vars and empty vars
   std::vector<details::VariableInfo> var_infos;
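As a side note on the garbage-collection gate introduced in this hunk, here is a hedged sketch of the intended semantics (the names below are illustrative stand-ins, not Paddle's internals): a negative value from GetEagerDeletionThreshold() disables eager deletion, and the ParallelGraph executor is excluded until the TODO above is resolved.

#include <cstdint>

// Illustrative stand-in for the executor-type enum read from exec_strategy.type_.
enum class ExecutorType { kDefault, kExperimental, kParallelGraph };

// Assumption: the threshold is a GB-valued knob converted to bytes, where any
// negative value means "eager deletion disabled".
int64_t ThresholdBytes(double gb) {
  return gb < 0 ? -1 : static_cast<int64_t>(gb * (1LL << 30));
}

// The guard the diff adds: prepare GC and reference counts only when a
// non-negative threshold is configured and the executor is not kParallelGraph.
bool ShouldPrepareGC(int64_t max_memory_size, ExecutorType type) {
  return max_memory_size >= 0 && type != ExecutorType::kParallelGraph;
}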