@@ -153,7 +153,6 @@ void ParallelExecutor::BCastParamsToGPUs(
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
-    VLOG(3) << "run broadcast " << var << " " << var_dev_id;
 
     auto &main_tensor = main_var->Get<LoDTensor>();
     auto &dims = main_tensor.dims();
@@ -184,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::NCCLGroupGuard guard;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-        platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                     nccl_ctx.comm_, nccl_ctx.stream());
+        if (initializing) {
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
+        } else {
+          if (static_cast<size_t>(var_dev_id)) {
+            platform::dynload::ncclBcast(buffers[i], numel, data_type,
+                                         var_dev_id, nccl_ctx.comm_,
+                                         nccl_ctx.stream());
+          }
+        }
       }
       member_->nccl_ctxs_->WaitAll();
     }
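
For readers unfamiliar with the pattern the second hunk edits: every device issues its own ncclBcast against its communicator with a shared root rank, the per-device calls are fused into one collective by the group guard, and the streams are drained afterwards. The sketch below is a minimal, hedged reconstruction of that flow using the raw NCCL and CUDA runtime APIs instead of Paddle's platform::dynload wrappers, NCCLGroupGuard, and WaitAll; kDevs, kRoot, and kNumel are illustrative stand-ins (kRoot plays the role of var_dev_id), not values taken from this PR.

// Sketch, not the PaddlePaddle implementation: single-process, multi-GPU
// broadcast of one buffer from a chosen root device. Error checking omitted
// for brevity.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  const int kDevs = 2;        // assumption: two visible GPUs
  const int kRoot = 1;        // assumption: the parameter lives on device 1
  const size_t kNumel = 1024; // assumption: element count of the parameter

  // One communicator per device, all owned by this single process.
  std::vector<ncclComm_t> comms(kDevs);
  std::vector<int> devs = {0, 1};
  ncclCommInitAll(comms.data(), kDevs, devs.data());

  std::vector<float *> buffers(kDevs);
  std::vector<cudaStream_t> streams(kDevs);
  for (int i = 0; i < kDevs; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&buffers[i], kNumel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }
  // In real use the root's buffer would already hold the parameter data.

  // Fuse the per-device calls into one collective launch; this is the job
  // platform::NCCLGroupGuard does in the diff. Every device, including the
  // root, calls ncclBcast with the same root rank.
  ncclGroupStart();
  for (int i = 0; i < kDevs; ++i) {
    ncclBcast(buffers[i], kNumel, ncclFloat, kRoot, comms[i], streams[i]);
  }
  ncclGroupEnd();

  // Block until every stream finishes, as member_->nccl_ctxs_->WaitAll()
  // does in the diff.
  for (int i = 0; i < kDevs; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
  }

  for (int i = 0; i < kDevs; ++i) {
    cudaSetDevice(i);
    cudaFree(buffers[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}

Note that ncclBcast is a collective: all ranks, the root included, must make the call with the same root argument, which is why the loop in the diff visits every entry of member_->places_ rather than skipping the source device.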