SetDev for nccl

helinwang-patch-1
Yu Yang 7 years ago
parent d7badb3ed2
commit 29cc9f308d

@ -358,7 +358,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
}
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
cudaSetDevice(dev_id);
platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
@ -519,7 +519,6 @@ void ParallelExecutor::ConstructDependencyGraph(
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var);
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
}
}

Loading…
Cancel
Save