@@ -39,18 +39,33 @@ void gemm<platform::CUDADeviceContext, float16>(
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 
-  const half h_alpha = static_cast<const half>(alpha);
-  const half h_beta = static_cast<const half>(beta);
-  const half* h_A = reinterpret_cast<const half*>(A);
-  const half* h_B = reinterpret_cast<const half*>(B);
-  half* h_C = reinterpret_cast<half*>(C);
+  float h_alpha = static_cast<float>(alpha);
+  float h_beta = static_cast<float>(beta);
 
   // TODO(kexinzhao): add processing code for compute capability < 53 case
   PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
-                    "cublas Hgemm requires GPU compute capability >= 53");
-  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
-      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
-      h_A, lda, &h_beta, h_C, N));
+                    "cublas fp16 gemm requires GPU compute capability >= 53");
+
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+#if CUDA_VERSION >= 9000
+  if (context.GetComputeCapability() >= 70) {
+    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
+                                                        CUBLAS_TENSOR_OP_MATH));
+    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+  } else {
+    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
+                                                        CUBLAS_DEFAULT_MATH));
+  }
+#endif  // CUDA_VERSION >= 9000
+
+  // cublasHgemm does true FP16 computation which is slow for non-Volta
+  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
+  // input/output in fp16, computation in fp32, which can also be accelerated
+  // using tensor cores in Volta GPUs.
+  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
+      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
+      CUDA_R_32F, algo));
 }
 
 template <>
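For reference, the sketch below shows the same mixed-precision call pattern outside the Paddle wrappers: __half input/output buffers, fp32 alpha/beta and accumulation (CUDA_R_32F), and the usual row-major trick of computing C^T = B^T * A^T because cuBLAS is column-major. This is a minimal illustration, not Paddle code: the helper name gemm_fp16_fp32acc is made up, error handling is reduced to asserts, and the CUDA 9 era enum names (CUBLAS_GEMM_DFALT, CUBLAS_GEMM_DFALT_TENSOR_OP, CUBLAS_TENSOR_OP_MATH) are assumed; newer toolkits spell some of these differently.

// Minimal sketch of fp16-in / fp32-accumulate GEMM via cublasGemmEx.
#include <cassert>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

// Row-major C (MxN) = A (MxK) * B (KxN); dA, dB, dC are device pointers.
void gemm_fp16_fp32acc(cublasHandle_t handle, int M, int N, int K,
                       const __half* dA, const __half* dB, __half* dC) {
  // alpha/beta must match the compute type (fp32 here), not the data type.
  float alpha = 1.0f;
  float beta = 0.0f;

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
  // Allow tensor core kernels; real code should gate this on
  // compute capability >= 70 as the diff above does.
  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
  algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif

  // cuBLAS is column-major, so compute C^T = B^T * A^T by swapping the
  // operand order, exactly as the wrapper in the diff does.
  cublasStatus_t st = cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
      dB, CUDA_R_16F, N,   // B^T is N x K with leading dimension N
      dA, CUDA_R_16F, K,   // A^T is K x M with leading dimension K
      &beta,
      dC, CUDA_R_16F, N,   // C^T is N x M with leading dimension N
      CUDA_R_32F, algo);   // accumulate in fp32
  assert(st == CUBLAS_STATUS_SUCCESS);
  (void)st;
}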