|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#define EIGEN_USE_GPU
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function_impl.h"
|
|
|
|
@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
|
|
|
|
|
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float16 alpha, const float16* A, const float16* B, const float16 beta,
|
|
|
|
|
float16* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
float16* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
// Note that cublas follows fortran order, so the order is different from
|
|
|
|
|
// the cblas convention.
|
|
|
|
@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
|
|
|
|
|
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
cublasOperation_t cuTransB =
|
|
|
|
|
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
const int strideC = M * N;
|
|
|
|
|
const int64_t strideC = M * N;
|
|
|
|
|
|
|
|
|
|
const half h_alpha = static_cast<const half>(alpha);
|
|
|
|
|
const half h_beta = static_cast<const half>(beta);
|
|
|
|
@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
|
|
|
|
|
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float alpha, const float* A, const float* B, const float beta,
|
|
|
|
|
float* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
float* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
// Note that cublas follows fortran order, so the order is different from
|
|
|
|
|
// the cblas convention.
|
|
|
|
@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
|
|
|
|
|
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
cublasOperation_t cuTransB =
|
|
|
|
|
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
const int strideC = M * N;
|
|
|
|
|
const int64_t strideC = M * N;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
|
|
|
|
|
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
|
|
|
|
@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
|
|
|
|
|
const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const double alpha, const double* A, const double* B, const double beta,
|
|
|
|
|
double* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
double* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
#if CUDA_VERSION >= 8000
|
|
|
|
|
// Note that cublas follows fortran order, so the order is different from
|
|
|
|
|
// the cblas convention.
|
|
|
|
@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
|
|
|
|
|
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
cublasOperation_t cuTransB =
|
|
|
|
|
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
|
|
|
|
|
const int strideC = M * N;
|
|
|
|
|
const int64_t strideC = M * N;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
|
|
|
|
|
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
|
|
|
|
|