|
|
|
@ -428,7 +428,8 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
|
|
|
|
|
const int64_t strideC = M * N;
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9010
|
|
|
|
|
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
|
|
|
|
|
if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same<T, float>::value)) ||
|
|
|
|
|
std::is_same<T, paddle::platform::float16>::value) {
|
|
|
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
|
|
|
|
|
bool use_tensor_op_math = context_.tensor_core_available();
|
|
|
|
|
if (use_tensor_op_math) {
|
|
|
|
@ -437,11 +438,11 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (use_tensor_op_math ? "True" : "False");
|
|
|
|
|
|
|
|
|
|
auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
|
|
|
|
|
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
|
|
|
|
|
handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
|
|
|
|
|
strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
|
|
|
|
|
strideC, batchCount, CUDA_R_32F, algo));
|
|
|
|
|
handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A,
|
|
|
|
|
fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo));
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
#endif // CUDA_VERSION >= 9010
|
|
|
|
|