|
|
|
@ -109,18 +109,18 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
|
|
|
|
|
|
|
|
|
|
void OpHandleBase::RunAndRecordEvent(platform::Place p,
|
|
|
|
|
const std::function<void()> &callback) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::is_cpu_place(p) || events_.empty()) {
|
|
|
|
|
callback();
|
|
|
|
|
} else {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto *ctx = dev_ctxes_.at(p);
|
|
|
|
|
auto *cuda_ctx = static_cast<platform::CUDADeviceContext *>(ctx);
|
|
|
|
|
cuda_ctx->RecordEvent(events_.at(boost::get<platform::CUDAPlace>(p).device),
|
|
|
|
|
callback);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("Not implemented");
|
|
|
|
|
callback();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|