@@ -39,13 +39,14 @@ void gemm<platform::CUDADeviceContext, float16>(
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  float h_alpha = static_cast<float>(alpha);
-  float h_beta = static_cast<float>(beta);
 
   // TODO(kexinzhao): add processing code for compute capability < 53 case
   PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                     "cublas fp16 gemm requires GPU compute capability >= 53");
 
 #if CUDA_VERSION >= 8000
+  float h_alpha = static_cast<float>(alpha);
+  float h_beta = static_cast<float>(beta);
+
   cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
   if (context.GetComputeCapability() >= 70) {
@@ -56,7 +57,7 @@ void gemm<platform::CUDADeviceContext, float16>(
     PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
                                                         CUBLAS_DEFAULT_MATH));
   }
-#endif
+#endif  // CUDA_VERSION >= 9000
 
   // cublasHgemm does true FP16 computation which is slow for non-Volta
   // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
@@ -66,6 +67,18 @@ void gemm<platform::CUDADeviceContext, float16>(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
       CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
       CUDA_R_32F, algo));
+#else
+  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      h_A, lda, &h_beta, h_C, N));
+#endif  // CUDA_VERSION >= 8000
 }
 
 template <>
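
For readers outside the Paddle codebase, here is a minimal standalone sketch of the call pattern this patch settles on for CUDA >= 8.0: cublasGemmEx with fp16 inputs/outputs and fp32 accumulation ("pseudo fp16"), plus the CUDA 9 tensor-core toggle. The helper name fp16_gemm, the fixed alpha/beta, and the omission of error checking are illustrative assumptions, not part of the patch; the argument order (B before A, N before M) is the same row-major-to-column-major trick used in the code above.

#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical helper: C = A * B for row-major fp16 matrices A (MxK),
// B (KxN), C (MxN), accumulating in fp32. Column-major cuBLAS computes
// C^T = B^T * A^T, hence B is passed first and M/N are swapped.
cublasStatus_t fp16_gemm(cublasHandle_t handle, int M, int N, int K,
                         const __half* A, const __half* B, __half* C) {
  float alpha = 1.0f, beta = 0.0f;  // fixed here for brevity
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
  // The patch enables tensor-op math only when compute capability >= 70;
  // pre-Volta GPUs simply ignore the hint and fall back to default math.
  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
  algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                      B, CUDA_R_16F, N,  // ldb = N when B is not transposed
                      A, CUDA_R_16F, K,  // lda = K when A is not transposed
                      &beta, C, CUDA_R_16F, N,
                      CUDA_R_32F,        // compute type: fp32 accumulation
                      algo);
}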