|
|
|
@ -166,7 +166,7 @@ void ParallelExecutor::BCastParamsToGPUs(
|
|
|
|
|
void *buffer;
|
|
|
|
|
|
|
|
|
|
if ((initializing && i == 0) ||
|
|
|
|
|
(!initializing && i == static_cast<size_t>(var_dev_id))) {
|
|
|
|
|
(!initializing && static_cast<int>(i) == var_dev_id)) {
|
|
|
|
|
buffer = const_cast<void *>(main_tensor.data<void>());
|
|
|
|
|
} else {
|
|
|
|
|
auto local_scope = member_->local_scopes_[i];
|
|
|
|
@ -187,7 +187,7 @@ void ParallelExecutor::BCastParamsToGPUs(
|
|
|
|
|
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
|
|
|
|
|
nccl_ctx.comm_, nccl_ctx.stream());
|
|
|
|
|
} else {
|
|
|
|
|
if (static_cast<size_t>(var_dev_id)) {
|
|
|
|
|
if (var_dev_id >= 0) {
|
|
|
|
|
platform::dynload::ncclBcast(buffers[i], numel, data_type,
|
|
|
|
|
var_dev_id, nccl_ctx.comm_,
|
|
|
|
|
nccl_ctx.stream());
|
|
|
|
|