|
|
|
@ -268,6 +268,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float16 alpha, const float16* A, const float16* B, const float16 beta,
|
|
|
|
|
float16* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
// Note that cublas follows fortran order, so the order is different from
|
|
|
|
|
// the cblas convention.
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
@ -289,7 +290,6 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
|
|
|
|
|
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
|
|
|
|
|
"cublas Hgemm requires GPU compute capability >= 53");
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
|
|
|
|
|
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
|
|
|
|
|
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
|
|
|
|
@ -304,6 +304,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float alpha, const float* A, const float* B, const float beta,
|
|
|
|
|
float* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
// Note that cublas follows fortran order, so the order is different from
|
|
|
|
|
// the cblas convention.
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
@ -315,7 +316,6 @@ void batched_gemm<platform::CUDADeviceContext, float>(
|
|
|
|
|
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
const int strideC = M * N;
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
|
|
|
|
|
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
|
|
|
|
|
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
|
|
|
|
@ -330,6 +330,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const double alpha, const double* A, const double* B, const double beta,
|
|
|
|
|
double* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
// Note that cublas follows fortran order, so the order is different from
|
|
|
|
|
// the cblas convention.
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
@ -341,7 +342,6 @@ void batched_gemm<platform::CUDADeviceContext, double>(
|
|
|
|
|
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
const int strideC = M * N;
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
|
|
|
|
|
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
|
|
|
|
|
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
|
|
|
|
|