@@ -281,6 +281,27 @@ bool ParallelExecutor::NeedCreateLocalExeScope() {
   return executor && executor->NeedCreateLocalExeScope();
 }
 
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+/*
+ * When nccl inits nccl comm using ncclCommInitAll, it meets error when
+ * allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
+ * create a new nccl comm for sync_batch_norm_op. And these codes should be
+ * polished with a unified nccl management.
+ */
+platform::NCCLContextMap *ParallelExecutor::GetNCCLContextForSyncbatchNomrOp(
+    framework::Scope *scope) {
+  auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
+  if (nccl_id_var != nullptr) {
+    return member_->nccl_ctxs_.DefaultFlatCtx();
+  }
+
+  if (dev_nccl_ctxs_.get() == nullptr) {
+    dev_nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+  }
+  return dev_nccl_ctxs_.get();
+}
+#endif
+
 ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
                                    const std::vector<std::string> &bcast_vars,
                                    const std::string &loss_var_name,
@@ -357,13 +378,13 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     // NOTE: NCCL group-calls and non-group-calls can not use the same
     // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
     // same communicators.
+    auto *nccl_ctxs = GetNCCLContextForSyncbatchNomrOp(scope);
     for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
       auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
           pool.Get(member_->places_[dev_id]));
-      auto &nccl_ctx =
-          member_->nccl_ctxs_.DefaultFlatCtx()->at(member_->places_[dev_id]);
+      auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]);
       dev_ctx->set_nccl_comm(nccl_ctx.comm());
     }
 #endif
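
For context on the NOTE above: the patch avoids mixing NCCL group-calls and non-group-calls on one communicator by giving sync_batch_norm_op its own communicator set. The following standalone sketch is not PaddlePaddle code; names such as grad_comms/sbn_comms, the buffer sizes, and the single-process multi-GPU setup are illustrative assumptions. It only shows the pattern: one communicator set used inside ncclGroupStart/ncclGroupEnd (as the gradient all-reduce handles do) and an independent set used for plain ncclAllReduce calls (as sync_batch_norm does), so the two kinds of calls never share a communicator.

// Sketch only: two independent NCCL communicator sets on the same GPUs.
// Error handling is omitted for brevity.
#include <cuda_runtime.h>
#include <nccl.h>
#include <thread>
#include <vector>

int main() {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);

  std::vector<ncclComm_t> grad_comms(ndev);  // used by grouped gradient all-reduce
  std::vector<ncclComm_t> sbn_comms(ndev);   // reserved for sync_batch_norm-style calls
  ncclCommInitAll(grad_comms.data(), ndev, nullptr);
  ncclCommInitAll(sbn_comms.data(), ndev, nullptr);

  std::vector<cudaStream_t> streams(ndev);
  std::vector<float *> buf(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamCreate(&streams[i]);
    cudaMalloc(&buf[i], sizeof(float));
  }

  // Grouped all-reduce, issued from a single thread for all devices.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    ncclAllReduce(buf[i], buf[i], 1, ncclFloat, ncclSum, grad_comms[i],
                  streams[i]);
  }
  ncclGroupEnd();

  // Plain (non-grouped) all-reduce, one thread per device, on the second
  // communicator set, so group-calls and non-group-calls never share a comm.
  std::vector<std::thread> workers;
  for (int i = 0; i < ndev; ++i) {
    workers.emplace_back([&, i] {
      cudaSetDevice(i);
      ncclAllReduce(buf[i], buf[i], 1, ncclFloat, ncclSum, sbn_comms[i],
                    streams[i]);
      cudaStreamSynchronize(streams[i]);
    });
  }
  for (auto &w : workers) w.join();

  for (int i = 0; i < ndev; ++i) {
    ncclCommDestroy(grad_comms[i]);
    ncclCommDestroy(sbn_comms[i]);
  }
  return 0;
}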