|
|
|
@ -345,8 +345,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Wait(platform::DeviceContext *waited_dev) override {
|
|
|
|
|
VLOG(3) << "Wait NCCL AllReduce";
|
|
|
|
|
this->dev_ctx_.at(waited_dev->GetPlace())->Wait();
|
|
|
|
|
for (auto &pair : member_->communication_streams_) {
|
|
|
|
|
pair.second.ctx_->Wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|