|
|
|
@ -402,10 +402,13 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
platform::dynload::ncclAllReduce(
|
|
|
|
|
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream());
|
|
|
|
|
PADDLE_ENFORCE(cudaEventRecord(events_[dev_id], nccl_ctx.stream()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
|
|
|
|
|
|
for (auto &ev : events_) {
|
|
|
|
|
PADDLE_ENFORCE(cudaEventRecord(
|
|
|
|
|
ev.second, member_->communication_streams_.at(ev.first).stream()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|