|
|
|
@ -126,14 +126,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
|
|
|
|
|
CUDA_R_32F, algo));
|
|
|
|
|
#else
|
|
|
|
|
// CUDA 7.5 does not support cublasGemmEx, so we fall back to hgemm
|
|
|
|
|
const half h_alpha = static_cast<const half>(alpha);
|
|
|
|
|
const half h_beta = static_cast<const half>(beta);
|
|
|
|
|
const half *h_A = reinterpret_cast<const half *>(A);
|
|
|
|
|
const half *h_B = reinterpret_cast<const half *>(B);
|
|
|
|
|
half *h_C = reinterpret_cast<half *>(C);
|
|
|
|
|
|
|
|
|
|
CUBlas<platform::float16>(context_.cublas_handle(), cuTransB, cuTransA, N, M,
|
|
|
|
|
K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N);
|
|
|
|
|
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
|
|
|
|
|
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
|
|
|
|
|
&h_beta, h_C, N);
|
|
|
|
|
#endif // CUDA_VERSION >= 8000
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|