helinwang-patch-1
Yu Yang 8 years ago
parent 99f85a9fbc
commit b94ffacbd7

@ -132,12 +132,12 @@ struct ScaleLossGradOpHandle : public OpHandle {
scope_(scope), scope_(scope),
place_(place) { place_(place) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device); cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
// Must set device before create event
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
VLOG(3) << "Create " << ev_;
} }
// Destroys the CUDA event (ev_) owned by this handle.
// The device recorded in place_ is made current first, mirroring the
// constructor, which selects that device before creating ev_.
~ScaleLossGradOpHandle() {
  // Check the status of the device switch rather than discarding it, so a
  // failed switch is reported instead of silently destroying the event
  // with some other device current.
  PADDLE_ENFORCE(
      cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device));
  PADDLE_ENFORCE(cudaEventDestroy(ev_));
}
@ -339,13 +339,15 @@ struct NCCLAllReduceOpHandle : public OpHandle {
// Builds the handle and creates one timing-disabled CUDA event for every
// entry of member_->communication_streams_, stored in events_ under that
// entry's device id.
//
// member: executor state providing communication_streams_; the pointer is
//         kept in member_.
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
    : member_(member) {
  for (auto &nccl : member_->communication_streams_) {
    const int dev_id = nccl.second.device_id();
    // Must set the device before creating the event; the status is checked
    // so that a failed switch does not go unnoticed.
    PADDLE_ENFORCE(cudaSetDevice(dev_id));
    PADDLE_ENFORCE(
        cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
  }
}
// Destroys every CUDA event held in events_. Each entry's key is passed to
// cudaSetDevice before its event is destroyed, matching the constructor,
// which sets the device before creating each event.
~NCCLAllReduceOpHandle() {
  for (auto &ev : events_) {
    // Check the device switch instead of ignoring its status, so the event
    // is never destroyed after a switch that silently failed.
    PADDLE_ENFORCE(cudaSetDevice(ev.first));
    PADDLE_ENFORCE(cudaEventDestroy(ev.second));
  }
}

Loading…
Cancel
Save