|
|
|
@ -31,6 +31,26 @@ struct CBlas<float> {
|
|
|
|
|
platform::dynload::cblas_sgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static float *GEMM_ALLOC(ARGS... args) {
|
|
|
|
|
return platform::dynload::cblas_sgemm_alloc(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_PACK(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sgemm_pack(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_COMPUTE(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sgemm_compute(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_FREE(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sgemm_free(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void SMM_GEMM(ARGS... args) {
|
|
|
|
@ -71,6 +91,26 @@ struct CBlas<double> {
|
|
|
|
|
platform::dynload::cblas_dgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static double *GEMM_ALLOC(ARGS... args) {
|
|
|
|
|
return platform::dynload::cblas_dgemm_alloc(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_PACK(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dgemm_pack(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_COMPUTE(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dgemm_compute(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_FREE(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dgemm_free(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void SMM_GEMM(ARGS... args) {
|
|
|
|
@ -224,6 +264,39 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
|
|
|
|
|
const int M, const int N,
|
|
|
|
|
const int K) const {
|
|
|
|
|
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
|
|
|
|
|
const CBLAS_TRANSPOSE trans,
|
|
|
|
|
int M, int N, int K,
|
|
|
|
|
const T alpha, const T *src,
|
|
|
|
|
const int ld, T *dst) const {
|
|
|
|
|
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
|
|
|
|
|
int transA, int transB, int M, int N, int K, const T *A, const int lda,
|
|
|
|
|
const T *B, const int ldb, T beta, T *C, const int ldc) const {
|
|
|
|
|
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
|
|
|
|
|
CBlas<T>::GEMM_FREE(data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|