@@ -320,14 +320,14 @@ struct NCCLAllReduceOpHandle : public OpHandle {
   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
       : member_(member) {
     for (auto &nccl : member_->communication_streams_) {
-      cudaEventCreate(&events_[nccl.second.device_id()],
-                      cudaEventDisableTiming);
+      PADDLE_ENFORCE(cudaEventCreate(&events_[nccl.second.device_id()],
+                                     cudaEventDisableTiming));
     }
   }
 
   ~NCCLAllReduceOpHandle() {
     for (auto &ev : events_) {
-      cudaEventDestroy(ev.second);
+      PADDLE_ENFORCE(cudaEventDestroy(ev.second));
     }
   }
 
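// Why this hunk: cudaEventCreate and cudaEventDestroy return a cudaError_t
// that the old code silently discarded; wrapping the calls in PADDLE_ENFORCE
// makes a failed CUDA runtime call fail loudly at the call site. Below is a
// minimal sketch of the same check, assuming only the CUDA runtime API; the
// CUDA_ENFORCE macro and CheckCuda helper are illustrative stand-ins, not
// Paddle's actual PADDLE_ENFORCE implementation.
#include <cuda_runtime.h>

#include <sstream>
#include <stdexcept>

#define CUDA_ENFORCE(expr) CheckCuda((expr), #expr, __FILE__, __LINE__)

inline void CheckCuda(cudaError_t status, const char *expr, const char *file,
                      int line) {
  if (status != cudaSuccess) {
    std::ostringstream msg;
    msg << "CUDA error `" << cudaGetErrorString(status) << "` at " << file
        << ":" << line << " in " << expr;
    throw std::runtime_error(msg.str());  // surface the failure immediately
  }
}

int main() {
  // Mirrors the constructor/destructor above: create a timing-disabled event
  // and destroy it, aborting with a readable message on any failure.
  cudaEvent_t ev;
  CUDA_ENFORCE(cudaEventCreateWithFlags(&ev, cudaEventDisableTiming));
  CUDA_ENFORCE(cudaEventDestroy(ev));
  return 0;
}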
@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream());
-        cudaEventRecord(events_[dev_id], nccl_ctx.stream());
+        PADDLE_ENFORCE(cudaEventRecord(events_[dev_id], nccl_ctx.stream()));
       }
 
       platform::dynload::ncclGroupEnd();
@@ -381,7 +381,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
             boost::get<platform::CUDAPlace>(waited_dev->GetPlace()).device;
         auto stream =
             static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
-        cudaStreamWaitEvent(stream, events_[dev_id], 0);
+        PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events_[dev_id], 0));
       }
     }
   }
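// Why the last two hunks: after ncclAllReduce is queued on the per-device
// NCCL stream, an event is recorded behind it (cudaEventRecord), and a stream
// that later needs the reduced tensor calls cudaStreamWaitEvent on that
// event, so the ordering is enforced on the GPU without a host-side sync;
// PADDLE_ENFORCE additionally makes a failing record/wait visible. Below is a
// minimal single-device sketch of the same record/wait pattern on plain CUDA
// streams; the Produce/Consume kernels and CUDA_ENFORCE macro are
// illustrative only, and NCCL itself is left out.
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

#define CUDA_ENFORCE(expr)                                            \
  do {                                                                \
    cudaError_t status_ = (expr);                                     \
    if (status_ != cudaSuccess) {                                     \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",                \
                   cudaGetErrorString(status_), __FILE__, __LINE__);  \
      std::exit(EXIT_FAILURE);                                        \
    }                                                                 \
  } while (0)

__global__ void Produce(float *buf, int n) {  // stands in for the all-reduce
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1.0f;
}

__global__ void Consume(float *buf, int n) {  // must see Produce's results
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] *= 2.0f;
}

int main() {
  const int n = 1 << 20;
  float *buf = nullptr;
  CUDA_ENFORCE(cudaMalloc(&buf, n * sizeof(float)));

  cudaStream_t producer, consumer;
  cudaEvent_t done;
  CUDA_ENFORCE(cudaStreamCreate(&producer));
  CUDA_ENFORCE(cudaStreamCreate(&consumer));
  CUDA_ENFORCE(cudaEventCreateWithFlags(&done, cudaEventDisableTiming));

  // "NCCL stream": queue the producing work, then record an event behind it.
  Produce<<<(n + 255) / 256, 256, 0, producer>>>(buf, n);
  CUDA_ENFORCE(cudaEventRecord(done, producer));

  // "Waited stream": all work queued after this wait runs after the event.
  CUDA_ENFORCE(cudaStreamWaitEvent(consumer, done, 0));
  Consume<<<(n + 255) / 256, 256, 0, consumer>>>(buf, n);

  CUDA_ENFORCE(cudaStreamSynchronize(consumer));
  CUDA_ENFORCE(cudaEventDestroy(done));
  CUDA_ENFORCE(cudaStreamDestroy(producer));
  CUDA_ENFORCE(cudaStreamDestroy(consumer));
  CUDA_ENFORCE(cudaFree(buf));
  return 0;
}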