|
|
|
@ -110,6 +110,8 @@ Event::Event(EventType type, std::string name, uint32_t thread_id,
|
|
|
|
|
has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false;
|
|
|
|
|
if (has_cuda_) {
|
|
|
|
|
auto* cuda_dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctx);
|
|
|
|
|
PADDLE_ENFORCE(cudaSetDevice(
|
|
|
|
|
boost::get<platform::CUDAPlace>(cuda_dev_ctx->GetPlace()).device));
|
|
|
|
|
PADDLE_ENFORCE(cudaGetDevice(&device_));
|
|
|
|
|
PADDLE_ENFORCE(cudaEventCreate(&event_));
|
|
|
|
|
auto stream = cuda_dev_ctx->stream();
|
|
|
|
|