@@ -68,9 +68,11 @@ struct CUBlas<float> {
 #if CUDA_VERSION >= 8000
     VLOG(5) << "use_tensor_op_math: "
             << (dev_ctx->tensor_core_available() ? "True" : "False");
-    PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
-        dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
-        alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc));
+    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
+      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
+          beta, C, Ctype, ldc));
+    });
 #else
     PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
 #endif
@@ -171,10 +173,11 @@ struct CUBlas<platform::float16> {
             << (use_tensor_op_math ? "True" : "False");
 #endif  // CUDA_VERSION >= 9000
 
-    PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-        dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
-        alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType,
-        algo));
+    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
+      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
+          beta, C, Ctype, ldc, computeType, algo));
+    });
 #else
     PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
 #endif
@@ -204,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                     CUDA_R_32F, N);
   } else {
 #endif  // CUDA_VERSION >= 8000
-
-    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                    &alpha, B, ldb, A, lda, &beta, C, N);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
+                      lda, &beta, C, N);
+    });
 
 #if CUDA_VERSION >= 8000
   }
@@ -247,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
-                                  &h_beta, h_C, N);
+
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
+                                    &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
+                                    N);
+  });
 #endif  // CUDA_VERSION >= 8000
 }
 
@@ -273,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
   } else {
 #endif  // CUDA_VERSION >= 8000
 
-    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                    &alpha, B, ldb, A, lda, &beta, C, ldc);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
+                      lda, &beta, C, ldc);
+    });
 
 #if CUDA_VERSION >= 8000
   }
@@ -292,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
-                                  ldc);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
+                                    B, ldb, A, lda, &beta, C, ldc);
+  });
 }
 
 template <>
 template <typename T>
 void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                              T *y) const {
-  CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
+  });
 }
 
 template <>
@@ -311,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                              T beta, T *C) const {
   cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
-                  &beta, C, 1);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
+  });
 }
 
 template <>
@@ -342,16 +355,20 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
     VLOG(5) << "use_tensor_op_math: "
             << (use_tensor_op_math ? "True" : "False");
 
-    PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
-        context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M,
-        K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA,
-        &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
+    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
+      PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+          handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
+          strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
+          strideC, batchCount, CUDA_R_32F, algo));
+    });
   } else {
 #endif  // CUDA_VERSION >= 9010
 
-    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
-                                  strideA, &beta, C, ldc, strideC, batchCount);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
+                                    B, ldb, strideB, A, lda, strideA, &beta, C,
+                                    ldc, strideC, batchCount);
+    });
 
 #if CUDA_VERSION >= 9010
   }
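
Note on the pattern above: every call site that previously fetched context_.cublas_handle() or possible_cublas_tensor_core_handle() directly now hands a lambda to CublasCall or TensorCoreCublasCallIfAvailable on the device context, so the context decides when, and under which math mode, the shared cublasHandle_t is touched. The wrappers themselves are not part of this diff; the sketch below is one plausible shape for them, assuming a mutex-serialized handle and a CUDA >= 9 tensor-op math toggle. The class name CublasHandleHolder, its constructor arguments, and the locking details are illustrative assumptions, not the actual implementation.

// A minimal sketch of the callback-based wrapper pattern this diff adopts.
// Illustrative only: the holder class, its constructor, and the mutex are
// assumptions; the real CublasCall/TensorCoreCublasCallIfAvailable live on
// the device context and may differ.
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <mutex>

class CublasHandleHolder {  // hypothetical helper
 public:
  CublasHandleHolder(cudaStream_t stream, bool tensor_core_available)
      : tensor_core_available_(tensor_core_available) {
    cublasCreate(&handle_);
    cublasSetStream(handle_, stream);
  }
  ~CublasHandleHolder() { cublasDestroy(handle_); }

  // Runs the callback with exclusive access to the shared cuBLAS handle, so
  // concurrent callers cannot interleave calls on the same handle.
  template <typename Callback>
  void CublasCall(Callback&& callback) const {
    std::lock_guard<std::mutex> guard(mtx_);
    callback(handle_);
  }

  // Same as CublasCall, but enables tensor-op math only for the duration of
  // this callback when the device supports it, then restores the default.
  template <typename Callback>
  void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    std::lock_guard<std::mutex> guard(mtx_);
#if CUDA_VERSION >= 9000
    if (tensor_core_available_) {
      cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH);
      callback(handle_);
      cublasSetMathMode(handle_, CUBLAS_DEFAULT_MATH);
      return;
    }
#endif
    callback(handle_);
  }

 private:
  cublasHandle_t handle_;
  bool tensor_core_available_;
  mutable std::mutex mtx_;
};

Scoping the math-mode switch to each callback keeps one caller's cublasSetMathMode from leaking into another caller's GEMM on the same handle, which is presumably why the tensor-core path also goes through a callback here instead of exposing a separate handle.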