fix broadcast bug

port
yi.wu 7 years ago
parent 88cb47bd86
commit 6f0107126a

@ -153,7 +153,6 @@ void ParallelExecutor::BCastParamsToGPUs(
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
VLOG(3) << "run broadcast " << var << " " << var_dev_id;
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
@ -184,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
if (initializing) {
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
if (static_cast<size_t>(var_dev_id)) {
platform::dynload::ncclBcast(buffers[i], numel, data_type,
var_dev_id, nccl_ctx.comm_,
nccl_ctx.stream());
}
}
}
member_->nccl_ctxs_->WaitAll();
}

Loading…
Cancel
Save