@@ -31,23 +31,24 @@ template <>
 struct CUBlas<float> {
   template <typename... ARGS>
   static void GEMM(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...));
   }

   template <typename... ARGS>
   static void AXPY(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
   }

   template <typename... ARGS>
   static void GEMV(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
   }

   template <typename... ARGS>
   static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cublasSgemmStridedBatched(args...));
 #else
     PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
@@ -69,7 +70,7 @@ struct CUBlas<float> {
     VLOG(5) << "use_tensor_op_math: "
             << (dev_ctx->tensor_core_available() ? "True" : "False");
     dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx(
           handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
           beta, C, Ctype, ldc));
     });
@@ -83,23 +84,24 @@ template <>
 struct CUBlas<double> {
   template <typename... ARGS>
   static void GEMM(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...));
   }

   template <typename... ARGS>
   static void AXPY(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
   }

   template <typename... ARGS>
   static void GEMV(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
   }

   template <typename... ARGS>
   static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cublasDgemmStridedBatched(args...));
 #else
     PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
@@ -120,7 +122,7 @@ struct CUBlas<platform::float16> {
                    const float16 *alpha, const float16 *A, int lda,
                    const float16 *B, int ldb, const float16 *beta, float16 *C,
                    int ldc) {
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
                                        reinterpret_cast<const __half *>(alpha),
                                        reinterpret_cast<const __half *>(A), lda,
@@ -140,7 +142,7 @@ struct CUBlas<platform::float16> {
                                  long long int strideC,  // NOLINT
                                  int batchCount) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const __half *>(alpha),
         reinterpret_cast<const __half *>(A), lda, strideA,
@@ -174,7 +176,7 @@ struct CUBlas<platform::float16> {
 #endif  // CUDA_VERSION >= 9000

     dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
           handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
           beta, C, Ctype, ldc, computeType, algo));
     });
@@ -356,7 +358,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
           << (use_tensor_op_math ? "True" : "False");

   context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-    PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
         handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
         strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
         strideC, batchCount, CUDA_R_32F, algo));
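
Note: the common change across every hunk above is that cuBLAS calls made through platform::dynload are now wrapped in PADDLE_ENFORCE_CUDA_SUCCESS, the CUDA-specific status check, instead of the generic PADDLE_ENFORCE, so the cublasStatus_t returned by each routine is validated as a CUDA status. The snippet below is only a minimal, self-contained sketch of that checking pattern; the CHECK_CUBLAS macro is hypothetical and is not Paddle's actual PADDLE_ENFORCE_CUDA_SUCCESS implementation.

#include <cublas_v2.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for a status-aware enforce macro: evaluate the call
// once, compare the result against CUBLAS_STATUS_SUCCESS, and report the
// numeric status plus file/line before aborting on failure.
#define CHECK_CUBLAS(call)                                          \
  do {                                                              \
    cublasStatus_t status_ = (call);                                \
    if (status_ != CUBLAS_STATUS_SUCCESS) {                         \
      std::fprintf(stderr, "cuBLAS error %d at %s:%d\n",            \
                   static_cast<int>(status_), __FILE__, __LINE__);  \
      std::abort();                                                 \
    }                                                               \
  } while (0)

// Usage mirrors the wrapped calls in the diff, e.g.:
//   CHECK_CUBLAS(cublasSgemm(handle, transa, transb, m, n, k,
//                            &alpha, A, lda, B, ldb, &beta, C, ldc));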