Add cublas_handle() to expose cublas_handle to ops (#31157)

* add get_cublas_handle() api

* update format

* add unittests

* alter function name
revert-31068-fix_conv3d_windows
liu zhengxi 4 years ago committed by GitHub
parent 406f4a7513
commit ae2be49f40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -108,6 +108,8 @@ class CublasHandleHolder {
} }
#endif #endif
const cublasHandle_t& GetCublasHandle() const { return handle_; }
~CublasHandleHolder() PADDLE_MAY_THROW { ~CublasHandleHolder() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_)); PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_));
@ -117,7 +119,7 @@ class CublasHandleHolder {
} }
template <typename Callback> template <typename Callback>
inline void Call(Callback &&callback) const { inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
callback(handle_); callback(handle_);
} }

@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle(); return context()->CudnnHandle();
} }
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
} }

@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
#endif #endif
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn /*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads. * functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock * Once the first cudnn function is called by the handle, a lock

@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) {
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
#endif #endif
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context; delete device_context;
} }
} }

Loading…
Cancel
Save