|
|
|
@ -19,74 +19,29 @@ namespace operators {
|
|
|
|
|
namespace math {
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB,
|
|
|
|
|
const int M,
|
|
|
|
|
const int N,
|
|
|
|
|
const int K,
|
|
|
|
|
const float alpha,
|
|
|
|
|
const float* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
const float* B,
|
|
|
|
|
const int ldb,
|
|
|
|
|
const float beta,
|
|
|
|
|
float* C,
|
|
|
|
|
const int ldc,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
cblas_sgemm(CblasRowMajor,
|
|
|
|
|
transA,
|
|
|
|
|
transB,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
alpha,
|
|
|
|
|
A,
|
|
|
|
|
lda,
|
|
|
|
|
B,
|
|
|
|
|
ldb,
|
|
|
|
|
beta,
|
|
|
|
|
C,
|
|
|
|
|
ldc);
|
|
|
|
|
void gemm<platform::CPUPlace, float>(
|
|
|
|
|
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
|
|
|
|
|
const int N, const int K, const float alpha, const float* A, const int lda,
|
|
|
|
|
const float* B, const int ldb, const float beta, float* C, const int ldc,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB,
|
|
|
|
|
const int M,
|
|
|
|
|
const int N,
|
|
|
|
|
const int K,
|
|
|
|
|
const double alpha,
|
|
|
|
|
const double* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
const double* B,
|
|
|
|
|
const int ldb,
|
|
|
|
|
const double beta,
|
|
|
|
|
double* C,
|
|
|
|
|
const int ldc,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
cblas_dgemm(CblasRowMajor,
|
|
|
|
|
transA,
|
|
|
|
|
transB,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
alpha,
|
|
|
|
|
A,
|
|
|
|
|
lda,
|
|
|
|
|
B,
|
|
|
|
|
ldb,
|
|
|
|
|
beta,
|
|
|
|
|
C,
|
|
|
|
|
ldc);
|
|
|
|
|
void gemm<platform::CPUPlace, double>(
|
|
|
|
|
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
|
|
|
|
|
const int N, const int K, const double alpha, const double* A,
|
|
|
|
|
const int lda, const double* B, const int ldb, const double beta, double* C,
|
|
|
|
|
const int ldc, platform::DeviceContext* context) {
|
|
|
|
|
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
|
|
|
|
|
bool in1_T,
|
|
|
|
|
const framework::Tensor& in2,
|
|
|
|
|
bool in2_T,
|
|
|
|
|
float alpha,
|
|
|
|
|
framework::Tensor* out,
|
|
|
|
|
void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
|
|
|
|
|
const framework::Tensor& in2, bool in2_T,
|
|
|
|
|
float alpha, framework::Tensor* out,
|
|
|
|
|
float beta,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
auto in1_dim = in1.dims();
|
|
|
|
@ -111,30 +66,17 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
|
|
|
|
|
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
|
|
|
|
|
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
|
|
|
|
|
|
|
|
|
|
gemm<platform::CPUPlace, float>(in1_Trans,
|
|
|
|
|
in2_Trans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
alpha,
|
|
|
|
|
in1.data<float>(),
|
|
|
|
|
K,
|
|
|
|
|
in2.data<float>(),
|
|
|
|
|
N,
|
|
|
|
|
beta,
|
|
|
|
|
out->data<float>(),
|
|
|
|
|
N,
|
|
|
|
|
context);
|
|
|
|
|
gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
|
|
|
|
|
in1.data<float>(), K, in2.data<float>(), N,
|
|
|
|
|
beta, out->data<float>(), N, context);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
|
|
|
|
|
void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
|
|
|
|
|
bool in1_T,
|
|
|
|
|
const framework::Tensor& in2,
|
|
|
|
|
bool in2_T,
|
|
|
|
|
float alpha,
|
|
|
|
|
framework::Tensor* out,
|
|
|
|
|
float beta,
|
|
|
|
|
bool in2_T, float alpha,
|
|
|
|
|
framework::Tensor* out, float beta,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
auto in1_dim = in1.dims();
|
|
|
|
|
auto in2_dim = in2.dims();
|
|
|
|
@ -157,20 +99,9 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
|
|
|
|
|
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
|
|
|
|
|
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
|
|
|
|
|
|
|
|
|
|
gemm<platform::CPUPlace, double>(in1_Trans,
|
|
|
|
|
in2_Trans,
|
|
|
|
|
M,
|
|
|
|
|
N,
|
|
|
|
|
K,
|
|
|
|
|
alpha,
|
|
|
|
|
in1.data<double>(),
|
|
|
|
|
K,
|
|
|
|
|
in2.data<double>(),
|
|
|
|
|
N,
|
|
|
|
|
beta,
|
|
|
|
|
out->data<double>(),
|
|
|
|
|
N,
|
|
|
|
|
context);
|
|
|
|
|
gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
|
|
|
|
|
in1.data<double>(), K, in2.data<double>(), N,
|
|
|
|
|
beta, out->data<double>(), N, context);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|