|
|
|
@ -250,6 +250,8 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
int dtype = -1;
|
|
|
|
|
size_t numel = 0;
|
|
|
|
|
|
|
|
|
|
platform::dynload::ncclGroupStart();
|
|
|
|
|
|
|
|
|
|
for (auto &p : member_->places_) {
|
|
|
|
|
int dev_id = boost::get<platform::CUDAPlace>(p).device;
|
|
|
|
|
|
|
|
|
@ -266,11 +268,12 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
|
|
|
|
|
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
|
|
|
|
|
|
|
|
|
|
ncclAllReduce(buffer, buffer, numel, static_cast<ncclDataType_t>(dtype),
|
|
|
|
|
ncclSum, nccl_ctx.comm, nccl_ctx.stream());
|
|
|
|
|
platform::dynload::ncclAllReduce(
|
|
|
|
|
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ncclGroupEnd();
|
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|