@@ -288,9 +288,14 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
@@ -310,9 +315,13 @@ void batched_gemm<platform::CUDADeviceContext, float>(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>
@@ -332,9 +341,13 @@ void batched_gemm<platform::CUDADeviceContext, double>(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int strideC = M * N;
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}
template <>