Follow comments

trainerSaveLoadParams
Yu Yang 7 years ago
parent 4db43c6c9f
commit caa4027d9d

@ -126,14 +126,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_32F, algo));
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
const half h_alpha = static_cast<const half>(alpha);
const half h_beta = static_cast<const half>(beta);
const half *h_A = reinterpret_cast<const half *>(A);
const half *h_B = reinterpret_cast<const half *>(B);
half *h_C = reinterpret_cast<half *>(C);
CUBlas<platform::float16>(context_.cublas_handle(), cuTransB, cuTransA, N, M,
K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N);
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
&h_beta, h_C, N);
#endif // CUDA_VERSION >= 8000
}

Loading…
Cancel
Save