|
|
|
@ -358,7 +358,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
|
|
|
|
|
|
|
|
|
|
cudaSetDevice(dev_id);
|
|
|
|
|
platform::dynload::ncclAllReduce(
|
|
|
|
|
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream());
|
|
|
|
@ -519,7 +519,6 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
var.name_ = og;
|
|
|
|
|
var.version_ = vars.size() - 1;
|
|
|
|
|
op_handle->outputs_.emplace_back(&var);
|
|
|
|
|
|
|
|
|
|
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|