helinwang-patch-1
Yu Yang 7 years ago
parent 8f0590e7c5
commit e8a7e5d1e6

@ -250,6 +250,8 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1;
size_t numel = 0;
platform::dynload::ncclGroupStart();
for (auto &p : member_->places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
@ -266,11 +268,12 @@ struct NCCLAllReduceOpHandle : public OpHandle {
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
ncclAllReduce(buffer, buffer, numel, static_cast<ncclDataType_t>(dtype),
ncclSum, nccl_ctx.comm, nccl_ctx.stream());
platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
}
ncclGroupEnd();
platform::dynload::ncclGroupEnd();
}
}
};

Loading…
Cancel
Save