helinwang-patch-1
Yu Yang 7 years ago
parent 99f85a9fbc
commit b94ffacbd7

@ -132,12 +132,12 @@ struct ScaleLossGradOpHandle : public OpHandle {
scope_(scope),
place_(place) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
// The device must be set before the event is created.
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
VLOG(3) << "Create " << ev_;
}
// Logs the event handle, makes the GPU stored in place_ current, and then
// destroys the CUDA event ev_.
~ScaleLossGradOpHandle() {
  VLOG(3) << "Destroy " << ev_;
  // Select the device recorded in place_ before releasing the event.
  const auto &gpu_place = boost::get<platform::CUDAPlace>(place_);
  cudaSetDevice(gpu_place.device);
  PADDLE_ENFORCE(cudaEventDestroy(ev_));
}
@ -339,13 +339,15 @@ struct NCCLAllReduceOpHandle : public OpHandle {
// Creates one timing-disabled CUDA event for each entry of
// member_->communication_streams_ and stores it in events_ under that
// entry's device id.
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
    : member_(member) {
  for (auto &nccl : member_->communication_streams_) {
    int dev_id = nccl.second.device_id();
    // The device must be current before its event is created. Check the
    // switch so a failure is reported instead of silently creating the event
    // on whichever device happened to be current.
    PADDLE_ENFORCE(cudaSetDevice(dev_id));
    // Create the event exactly once per device id; creating it a second time
    // into the same events_ slot would overwrite (and leak) the first handle.
    PADDLE_ENFORCE(
        cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
  }
}
// Destroys every event held in events_. Each entry's key (the device id the
// event was stored under) is made the current device before the event is
// destroyed.
~NCCLAllReduceOpHandle() {
  for (auto it = events_.begin(); it != events_.end(); ++it) {
    cudaSetDevice(it->first);
    PADDLE_ENFORCE(cudaEventDestroy(it->second));
  }
}

Loading…
Cancel
Save