|
|
|
@ -62,19 +62,27 @@ struct CUBlas<float> {
|
|
|
|
|
cudaDataType_t Atype, int lda, const void *B,
|
|
|
|
|
cudaDataType_t Btype, int ldb, const float *beta, void *C,
|
|
|
|
|
cudaDataType_t Ctype, int ldc) {
|
|
|
|
|
// Because the gcc 4.8 doesn't expand template parameter pack that
|
|
|
|
|
// appears in a lambda-expression, I can not use template parameter pack
|
|
|
|
|
// here.
|
|
|
|
|
// Because the gcc 4.8 doesn't expand template parameter pack that
|
|
|
|
|
// appears in a lambda-expression, I can not use template parameter pack
|
|
|
|
|
// here.
|
|
|
|
|
auto cublas_call = [&]() {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (dev_ctx->tensor_core_available() ? "True" : "False");
|
|
|
|
|
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (platform::TensorCoreAvailable() ? "True" : "False");
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
|
|
|
|
|
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
|
|
|
|
|
beta, C, Ctype, ldc));
|
|
|
|
|
});
|
|
|
|
|
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
|
|
|
|
|
lda, B, Btype, ldb, beta, C, Ctype, ldc));
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
|
|
|
|
|
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
|
|
|
|
|
#endif
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000
|
|
|
|
|
// NOTES: To use Tensor Core, we should change the cublas config,
|
|
|
|
|
// but the cublas may be hold by multi-thread.
|
|
|
|
|
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
|
|
|
|
|
#else
|
|
|
|
|
cublas_call();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -162,24 +170,32 @@ struct CUBlas<platform::float16> {
|
|
|
|
|
cudaDataType_t Btype, int ldb, const void *beta, void *C,
|
|
|
|
|
cudaDataType_t Ctype, int ldc,
|
|
|
|
|
cudaDataType_t computeType) {
|
|
|
|
|
auto cublas_call = [&]() {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
|
|
|
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
|
|
|
|
|
#if CUDA_VERSION >= 9000
|
|
|
|
|
bool use_tensor_op_math = dev_ctx->tensor_core_available();
|
|
|
|
|
if (use_tensor_op_math) {
|
|
|
|
|
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (use_tensor_op_math ? "True" : "False");
|
|
|
|
|
bool use_tensor_op_math = platform::TensorCoreAvailable();
|
|
|
|
|
if (use_tensor_op_math) {
|
|
|
|
|
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (use_tensor_op_math ? "True" : "False");
|
|
|
|
|
#endif // CUDA_VERSION >= 9000
|
|
|
|
|
|
|
|
|
|
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
|
|
|
|
|
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
|
|
|
|
|
beta, C, Ctype, ldc, computeType, algo));
|
|
|
|
|
});
|
|
|
|
|
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
|
|
|
|
|
lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
|
|
|
|
|
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
|
|
|
|
|
#endif
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000
|
|
|
|
|
// NOTES: To use Tensor Core, we should change the cublas config,
|
|
|
|
|
// but the cublas may be hold by multi-thread.
|
|
|
|
|
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
|
|
|
|
|
#else
|
|
|
|
|
cublas_call();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -207,10 +223,9 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|
CUDA_R_32F, N);
|
|
|
|
|
} else {
|
|
|
|
|
#endif // CUDA_VERSION >= 8000
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
|
|
|
|
|
lda, &beta, C, N);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
|
|
|
|
|
&alpha, B, ldb, A, lda, &beta, C, N);
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
}
|
|
|
|
@ -251,12 +266,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
|
|
|
|
|
CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
|
|
|
|
|
#else
|
|
|
|
|
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
|
|
|
|
|
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
|
|
|
|
|
&h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
|
|
|
|
|
N);
|
|
|
|
|
});
|
|
|
|
|
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
|
|
|
|
|
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
|
|
|
|
|
&h_beta, h_C, N);
|
|
|
|
|
#endif // CUDA_VERSION >= 8000
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -280,10 +292,8 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
|
|
|
|
|
} else {
|
|
|
|
|
#endif // CUDA_VERSION >= 8000
|
|
|
|
|
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
|
|
|
|
|
lda, &beta, C, ldc);
|
|
|
|
|
});
|
|
|
|
|
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
|
|
|
|
|
&alpha, B, ldb, A, lda, &beta, C, ldc);
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
}
|
|
|
|
@ -301,19 +311,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
|
|
|
|
|
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
|
|
|
|
|
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
|
|
|
|
|
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
|
|
|
|
|
B, ldb, A, lda, &beta, C, ldc);
|
|
|
|
|
});
|
|
|
|
|
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
|
|
|
|
|
N, M, K, &alpha, B, ldb, A, lda, &beta, C,
|
|
|
|
|
ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
|
|
|
|
|
T *y) const {
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
|
|
|
|
|
});
|
|
|
|
|
CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -323,9 +330,8 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
|
|
|
|
|
T beta, T *C) const {
|
|
|
|
|
cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
|
|
|
|
|
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
|
|
|
|
|
});
|
|
|
|
|
CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
|
|
|
|
|
&beta, C, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -347,28 +353,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9010
|
|
|
|
|
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
|
|
|
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
|
|
|
|
|
bool use_tensor_op_math = context_.tensor_core_available();
|
|
|
|
|
if (use_tensor_op_math) {
|
|
|
|
|
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (use_tensor_op_math ? "True" : "False");
|
|
|
|
|
|
|
|
|
|
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
|
|
|
|
|
auto cublas_call = [&]() {
|
|
|
|
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
|
|
|
|
|
bool use_tensor_op_math = platform::TensorCoreAvailable();
|
|
|
|
|
if (use_tensor_op_math) {
|
|
|
|
|
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "use_tensor_op_math: "
|
|
|
|
|
<< (use_tensor_op_math ? "True" : "False");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
|
|
|
|
|
handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
|
|
|
|
|
strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
|
|
|
|
|
strideC, batchCount, CUDA_R_32F, algo));
|
|
|
|
|
});
|
|
|
|
|
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
|
|
|
|
|
CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
|
|
|
|
|
CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
|
|
|
|
|
};
|
|
|
|
|
auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
|
|
|
|
|
dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
|
|
|
|
|
} else {
|
|
|
|
|
#endif // CUDA_VERSION >= 9010
|
|
|
|
|
|
|
|
|
|
context_.CublasCall([&](cublasHandle_t handle) {
|
|
|
|
|
CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
|
|
|
|
|
B, ldb, strideB, A, lda, strideA, &beta, C,
|
|
|
|
|
ldc, strideC, batchCount);
|
|
|
|
|
});
|
|
|
|
|
CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
|
|
|
|
|
N, M, K, &alpha, B, ldb, strideB, A, lda,
|
|
|
|
|
strideA, &beta, C, ldc, strideC, batchCount);
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9010
|
|
|
|
|
}
|
|
|
|
|