@@ -62,27 +62,17 @@ struct CUBlas<float> {
                       cudaDataType_t Atype, int lda, const void *B,
                       cudaDataType_t Btype, int ldb, const float *beta, void *C,
                       cudaDataType_t Ctype, int ldc) {
-    // Because the gcc 4.8 doesn't expand template parameter pack that
-    // appears in a lambda-expression, I can not use template parameter pack
-    // here.
-    auto cublas_call = [&]() {
+// Because the gcc 4.8 doesn't expand template parameter pack that
+// appears in a lambda-expression, I can not use template parameter pack
+// here.
 #if CUDA_VERSION >= 8000
-      VLOG(5) << "use_tensor_op_math: "
-              << (platform::TensorCoreAvailable() ? "True" : "False");
-      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
-          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
-          lda, B, Btype, ldb, beta, C, Ctype, ldc));
+    VLOG(5) << "use_tensor_op_math: "
+            << (dev_ctx->tensor_core_available() ? "True" : "False");
+    PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+        dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
+        alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc));
 #else
-      PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
-#endif
-    };
-
-#if CUDA_VERSION >= 9000
-    // NOTES: To use Tensor Core, we should change the cublas config,
-    // but the cublas may be hold by multi-thread.
-    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
-#else
-    cublas_call();
+    PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
 #endif
   }
 };
@@ -170,32 +160,23 @@ struct CUBlas<platform::float16> {
                       cudaDataType_t Btype, int ldb, const void *beta, void *C,
                       cudaDataType_t Ctype, int ldc,
                       cudaDataType_t computeType) {
-    auto cublas_call = [&]() {
 #if CUDA_VERSION >= 8000
-      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-      bool use_tensor_op_math = platform::TensorCoreAvailable();
-      if (use_tensor_op_math) {
-        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-      }
-      VLOG(5) << "use_tensor_op_math: "
-              << (use_tensor_op_math ? "True" : "False");
+    bool use_tensor_op_math = dev_ctx->tensor_core_available();
+    if (use_tensor_op_math) {
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+    }
+    VLOG(5) << "use_tensor_op_math: "
+            << (use_tensor_op_math ? "True" : "False");
 #endif  // CUDA_VERSION >= 9000
 
-      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
-          lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
-#else
-      PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
-#endif
-    };
-
-#if CUDA_VERSION >= 9000
-    // NOTES: To use Tensor Core, we should change the cublas config,
-    // but the cublas may be hold by multi-thread.
-    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+    PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+        dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k,
+        alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType,
+        algo));
 #else
-    cublas_call();
+    PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
 #endif
   }
 };
@@ -353,22 +334,18 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
 
 #if CUDA_VERSION >= 9010
   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
-    auto cublas_call = [&]() {
-      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-      bool use_tensor_op_math = platform::TensorCoreAvailable();
-      if (use_tensor_op_math) {
-        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-      }
-      VLOG(5) << "use_tensor_op_math: "
-              << (use_tensor_op_math ? "True" : "False");
-
-      PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
-          context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
-          CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
-          CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
-    };
-    auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-    dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+    bool use_tensor_op_math = context_.tensor_core_available();
+    if (use_tensor_op_math) {
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+    }
+    VLOG(5) << "use_tensor_op_math: "
+            << (use_tensor_op_math ? "True" : "False");
+
+    PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+        context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M,
+        K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA,
+        &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
   } else {
 #endif  // CUDA_VERSION >= 9010
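
The batched path follows the same pattern: FLAGS_enable_cublas_tensor_op_math gates a float-only fast path that calls cublasGemmStridedBatchedEx on the possibly-Tensor-Core handle with a matching algorithm, and the lambda plus CublasCall(..., CUBLAS_TENSOR_OP_MATH) disappears. Below is a self-contained sketch of such a call with FP32 data throughout, following the argument shape of the hunk; the wrapper name and parameter list are illustrative, and the caller's row-major/column-major trick of passing B before A is left to the caller. cublasGemmStridedBatchedEx requires CUDA >= 9.1, matching the guard above.

#include <cublas_v2.h>

// Single-precision strided-batched GEMM; the handle may be an ordinary one
// or one preconfigured for Tensor Core math, as in the diff.
cublasStatus_t StridedBatchedSgemmEx(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const float *alpha, const float *A, int lda,
    long long strideA, const float *B, int ldb, long long strideB,
    const float *beta, float *C, int ldc, long long strideC, int batchCount,
    bool use_tensor_op_math) {
  cublasGemmAlgo_t algo =
      use_tensor_op_math ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT;
  return cublasGemmStridedBatchedEx(
      handle, transa, transb, m, n, k, alpha, A, CUDA_R_32F, lda, strideA, B,
      CUDA_R_32F, ldb, strideB, beta, C, CUDA_R_32F, ldc, strideC, batchCount,
      CUDA_R_32F, algo);
}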