@@ -45,6 +45,9 @@ void gemm<platform::CUDADeviceContext, float16>(
   const half* h_B = reinterpret_cast<const half*>(B);
   half* h_C = reinterpret_cast<half*>(C);
 
+  // TODO(kexinzhao): add processing code for compute capability < 53 case
+  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
+                    "cublas Hgemm requires GPU compute capability >= 53");
   PADDLE_ENFORCE(platform::dynload::cublasHgemm(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
       h_A, lda, &h_beta, h_C, N));
@@ -106,6 +109,9 @@ void gemm<platform::CUDADeviceContext, float16>(
   const half* h_B = reinterpret_cast<const half*>(B);
   half* h_C = reinterpret_cast<half*>(C);
 
+  // TODO(kexinzhao): add processing code for compute capability < 53 case
+  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
+                    "cublas Hgemm requires GPU compute capability >= 53");
   PADDLE_ENFORCE(platform::dynload::cublasHgemm(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
       h_A, lda, &h_beta, h_C, ldc));
@@ -251,6 +257,9 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
   const half* h_B = reinterpret_cast<const half*>(B);
   half* h_C = reinterpret_cast<half*>(C);
 
+  // TODO(kexinzhao): add processing code for compute capability < 53 case
+  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
+                    "cublas Hgemm requires GPU compute capability >= 53");
   PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
       strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
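
Not part of the diff: a minimal standalone sketch of the same guard outside Paddle's `CUDADeviceContext`, using only the plain CUDA runtime API. The helper name `SupportsHgemm`, the device index `0`, and the `main` driver are assumptions for illustration; the `53` threshold is the one introduced by the change above (cublasHgemm needs compute capability >= 5.3 for native fp16 math).

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical helper: true if the given device reports compute capability
// >= 5.3, i.e. the cublasHgemm path guarded above would be allowed.
bool SupportsHgemm(int device_id) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;  // treat a failed query as "no fp16 GEMM support"
  }
  int capability = prop.major * 10 + prop.minor;  // e.g. 5.3 -> 53
  return capability >= 53;
}

int main() {
  if (SupportsHgemm(0)) {
    std::printf("fp16 GEMM (cublasHgemm) path available\n");
  } else {
    std::printf("compute capability < 53: fall back to a float GEMM path\n");
  }
  return 0;
}
```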