Use event to sync stream

helinwang-patch-1
Yu Yang 7 years ago
parent 3aa7051b98
commit d7badb3ed2

@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_;
std::vector<cudaEvent_t> events_;
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) {}
: member_(member) {
events_.resize(member_->places_.size());
for (auto &ev : events_) {
cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
}
}
~NCCLAllReduceOpHandle() {
for (auto &ev : events_) {
cudaEventDestroy(ev);
}
}
void Run() override {
if (this->inputs_.size() == 1) {
@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
cudaEventRecord(events_[i], nccl_ctx.stream());
}
platform::dynload::ncclGroupEnd();
@ -357,8 +370,19 @@ struct NCCLAllReduceOpHandle : public OpHandle {
}
void Wait(platform::DeviceContext *waited_dev) override {
for (auto &pair : member_->communication_streams_) {
pair.second.ctx_->Wait();
if (platform::is_cpu_place(
waited_dev->GetPlace())) { // Wait by CPU, just sync stream
for (auto &pair : member_->communication_streams_) {
pair.second.ctx_->Wait();
}
} else {
if (events_.size() > 1) {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
cudaStreamWaitEvent(stream, ev, 0);
}
}
}
}
};

Loading…
Cancel
Save