@@ -19,12 +19,13 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
+                                     const CBLAS_TRANSPOSE transA,
                                      const CBLAS_TRANSPOSE transB, const int M,
                                      const int N, const int K,
                                      const float alpha, const float* A,
-                                     const float* B, const float beta, float* C,
-                                     platform::DeviceContext* context) {
+                                     const float* B, const float beta,
+                                     float* C) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -35,18 +36,19 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 
 template <>
-void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
+                                      const CBLAS_TRANSPOSE transA,
                                       const CBLAS_TRANSPOSE transB, const int M,
                                       const int N, const int K,
                                       const double alpha, const double* A,
                                       const double* B, const double beta,
-                                      double* C,
-                                      platform::DeviceContext* context) {
+                                      double* C) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -56,18 +58,16 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
+      reinterpret_cast<const platform::CUDADeviceContext&>(context)
+          .cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 
 template <>
-void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
-                                       bool trans_a,
-                                       const framework::Tensor& matrix_b,
-                                       bool trans_b, float alpha,
-                                       framework::Tensor* matrix_out,
-                                       float beta,
-                                       platform::DeviceContext* context) {
+void matmul<platform::GPUPlace, float>(
+    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
+    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha,
+    framework::Tensor* matrix_out, float beta) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -87,18 +87,15 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::GPUPlace, float>(
-      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
-      matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
+      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>());
 }
 
 template <>
-void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
-                                        bool trans_a,
-                                        const framework::Tensor& matrix_b,
-                                        bool trans_b, double alpha,
-                                        framework::Tensor* matrix_out,
-                                        double beta,
-                                        platform::DeviceContext* context) {
+void matmul<platform::GPUPlace, double>(
+    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
+    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha,
+    framework::Tensor* matrix_out, double beta) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -118,8 +115,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
   CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::GPUPlace, double>(
-      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
-      matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
+      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
 
 } // namespace math
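
Background for the `// Note that cublas follows fortran order` comment kept in both gemm specializations: cuBLAS assumes column-major (Fortran-order) storage, while these buffers are row-major. A row-major product C = A * B is therefore evaluated as the column-major product C^T = B^T * A^T, which is why the cublasSgemm/cublasDgemm calls swap the operands and the M/N dimensions (`cuTransB, cuTransA, N, M, K, ..., B, ldb, A, lda, ..., C, N`). Below is a minimal standalone sketch of the same trick against the raw cuBLAS API; it is illustrative only, not part of this patch, with error checks omitted and all names hypothetical.

// Sketch: row-major C = A * B through cuBLAS by computing the column-major
// product C^T = B^T * A^T, so B is passed first, A second, and M/N are swapped.
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 4;
  std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);

  float *d_a, *d_b, *d_c;
  cudaMalloc(reinterpret_cast<void**>(&d_a), A.size() * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&d_b), B.size() * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&d_c), C.size() * sizeof(float));
  cudaMemcpy(d_a, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  const float alpha = 1.0f, beta = 0.0f;
  // Row-major leading dimensions: lda = K for A (M x K), ldb = N for B (K x N),
  // ldc = N for C (M x N). Operands and M/N are swapped, as in the code above.
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, d_b, N, d_a, K,
              &beta, d_c, N);

  cudaMemcpy(C.data(), d_c, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("C[0][0] = %.1f (expected 3.0)\n", C[0]);  // every entry equals K

  cublasDestroy(handle);
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
  return 0;
}

With all-ones inputs every entry of the row-major result equals K, which makes it easy to confirm that the swapped call really produces A * B and not B * A.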