|
|
|
@ -188,26 +188,16 @@ bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
|
|
|
|
|
callback();
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (!events_.empty()) { // Use event
|
|
|
|
|
std::function<void()> method = callback;
|
|
|
|
|
for (auto &p : dev_ctxes_) {
|
|
|
|
|
method = [method, p, this]() {
|
|
|
|
|
VLOG(10) << "cudadevicecontext:"
|
|
|
|
|
<< static_cast<platform::CUDADeviceContext *>(p.second)
|
|
|
|
|
<< ", dev_id:"
|
|
|
|
|
<< boost::get<platform::CUDAPlace>(p.first).device;
|
|
|
|
|
|
|
|
|
|
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
|
|
|
|
|
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
|
|
|
|
|
method);
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
method();
|
|
|
|
|
} else {
|
|
|
|
|
#endif
|
|
|
|
|
callback();
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto dev_id = boost::get<platform::CUDAPlace>(p.first).device;
|
|
|
|
|
auto *cuda_dev_ctx = static_cast<platform::CUDADeviceContext *>(p.second);
|
|
|
|
|
VLOG(10) << "cudadevicecontext:" << cuda_dev_ctx << ", dev_id:" << dev_id;
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
cudaEventRecord(events_.at(dev_id), cuda_dev_ctx->stream()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|