@@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
 
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
+  std::vector<cudaEvent_t> events_;
 
   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
-      : member_(member) {}
+      : member_(member) {
+    events_.resize(member_->places_.size());
+    for (auto &ev : events_) {
+      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+    }
+  }
+
+  ~NCCLAllReduceOpHandle() {
+    for (auto &ev : events_) {
+      cudaEventDestroy(ev);
+    }
+  }
 
   void Run() override {
     if (this->inputs_.size() == 1) {
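The constructor/destructor pair added above owns one synchronization event per place. As a standalone illustration of that lifecycle, here is a minimal sketch; `EventPool` and `num_places` are illustrative names standing in for the op handle and `member_->places_.size()`, and only the CUDA runtime calls are taken from the patch:

```cpp
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

// Minimal sketch of the per-place event pool the handle now owns.
struct EventPool {
  std::vector<cudaEvent_t> events_;

  explicit EventPool(size_t num_places) {
    events_.resize(num_places);
    for (auto &ev : events_) {
      // cudaEventDisableTiming skips timestamp bookkeeping, which makes
      // record/wait cheaper when events are used purely for ordering.
      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    }
  }

  ~EventPool() {
    for (auto &ev : events_) {
      cudaEventDestroy(ev);
    }
  }
};

int main() {
  EventPool pool(2);  // e.g. two places/GPUs
  std::printf("created %zu events\n", pool.events_.size());
  return 0;
}
```

Note that a CUDA event is tied to the device that is current when it is created, so a fully multi-GPU version would typically switch devices (`cudaSetDevice`) before each `cudaEventCreateWithFlags` call.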
@@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream());
+        cudaEventRecord(events_[i], nccl_ctx.stream());
       }
 
       platform::dynload::ncclGroupEnd();
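The new `cudaEventRecord` marks, on each device's communication stream, the point by which that device's all-reduce has completed, since an event recorded on a stream fires only after all work enqueued before it. A minimal single-GPU sketch of this record-after-enqueue pattern, with `cudaMemsetAsync` standing in for the NCCL call (`buf`, `kBytes`, and the host-side `cudaEventSynchronize` are illustrative, not part of the patch):

```cpp
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kBytes = 1 << 20;
  void *buf = nullptr;
  cudaMalloc(&buf, kBytes);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaEvent_t done;
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);

  cudaMemsetAsync(buf, 0, kBytes, stream);  // stand-in for ncclAllReduce
  cudaEventRecord(done, stream);  // fires once the stream reaches this point

  cudaEventSynchronize(done);  // host-side wait, just for the demo
  std::printf("stream work finished\n");

  cudaEventDestroy(done);
  cudaStreamDestroy(stream);
  cudaFree(buf);
  return 0;
}
```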
@@ -357,9 +370,20 @@ struct NCCLAllReduceOpHandle : public OpHandle {
     }
   }
 
   void Wait(platform::DeviceContext *waited_dev) override {
-    for (auto &pair : member_->communication_streams_) {
-      pair.second.ctx_->Wait();
+    if (platform::is_cpu_place(
+            waited_dev->GetPlace())) {  // Wait by CPU, just sync stream
+      for (auto &pair : member_->communication_streams_) {
+        pair.second.ctx_->Wait();
+      }
+    } else {
+      if (events_.size() > 1) {
+        auto stream =
+            static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+        for (auto &ev : events_) {
+          cudaStreamWaitEvent(stream, ev, 0);
+        }
+      }
     }
   }
 };
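The `else` branch replaces a host-blocking wait with a device-side one: the waiting CUDA stream is told to wait on each recorded event, so the ordering is enforced by the GPU scheduler while the CPU keeps running. A minimal two-stream sketch of that `cudaStreamWaitEvent` pattern (`producer` and `consumer` are illustrative names, not from the patch):

```cpp
#include <cuda_runtime.h>

#include <cstdio>

int main() {
  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);

  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);

  void *buf = nullptr;
  cudaMalloc(&buf, 1 << 20);

  cudaMemsetAsync(buf, 1, 1 << 20, producer);  // work on the producer stream
  cudaEventRecord(ev, producer);

  // Anything enqueued on `consumer` after this call runs only once `ev`
  // has fired on `producer`; the host thread never blocks here.
  cudaStreamWaitEvent(consumer, ev, 0);
  cudaMemsetAsync(buf, 2, 1 << 20, consumer);

  cudaStreamSynchronize(consumer);
  std::printf("consumer saw producer's work\n");

  cudaEventDestroy(ev);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
  cudaFree(buf);
  return 0;
}
```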