|
|
|
@ -29,7 +29,14 @@ namespace platform {
|
|
|
|
|
class CublasHandleHolder {
|
|
|
|
|
public:
|
|
|
|
|
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cublasCreate(&handle_));
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
dynload::cublasCreate(&handle_),
|
|
|
|
|
platform::errors::External(
|
|
|
|
|
"The cuBLAS library was not initialized. This is usually caused by "
|
|
|
|
|
"an error in the CUDA Runtime API called by the cuBLAS routine, or "
|
|
|
|
|
"an error in the hardware setup.\n"
|
|
|
|
|
"To correct: check that the hardware, an appropriate version of "
|
|
|
|
|
"the driver, and the cuBLAS library are correctly installed."));
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cublasSetStream(handle_, stream));
|
|
|
|
|
#if CUDA_VERSION >= 9000
|
|
|
|
|
if (math_type == CUBLAS_TENSOR_OP_MATH) {
|
|
|
|
|