|
|
|
@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
|
|
|
|
|
template <typename Callback>
|
|
|
|
|
void RecordEvent(cudaEvent_t ev, Callback callback) {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> guard(mutex_);
|
|
|
|
|
callback();
|
|
|
|
|
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
|
|
|
|
|
}
|
|
|
|
@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
|
|
|
|
|
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
|
|
|
|
|
|
|
|
|
|
mutable std::recursive_mutex mutex_;
|
|
|
|
|
cudaStream_t stream_;
|
|
|
|
|
cudnnHandle_t cudnn_handle_;
|
|
|
|
|
cublasHandle_t cublas_handle_;
|
|
|
|
|