|
|
@ -108,14 +108,13 @@ struct OpHandle {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
if (platform::is_cpu_place(waited_dev->GetPlace()) && events_.empty()) {
|
|
|
|
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
|
|
|
|
for (auto &dev_ctx : dev_ctx_) {
|
|
|
|
for (auto &dev_ctx : dev_ctx_) {
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
auto stream =
|
|
|
|
auto stream =
|
|
|
|
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
|
|
|
|
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &ev : events_) {
|
|
|
|
for (auto &ev : events_) {
|
|
|
|
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
|
|
|
|
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
|
|
|
|
}
|
|
|
|
}
|
|
|
|