|
|
|
@ -30,6 +30,12 @@ struct CBlas<float> {
|
|
|
|
|
platform::dynload::cblas_sgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void SMM_GEMM(ARGS... args) {
|
|
|
|
|
libxsmm_sgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void AXPY(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_saxpy(args...);
|
|
|
|
@ -63,6 +69,12 @@ struct CBlas<double> {
|
|
|
|
|
platform::dynload::cblas_dgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void SMM_GEMM(ARGS... args) {
|
|
|
|
|
libxsmm_dgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void AXPY(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_daxpy(args...);
|
|
|
|
@ -140,6 +152,9 @@ struct CBlas<double> {
|
|
|
|
|
template <>
|
|
|
|
|
struct CBlas<platform::float16> {
|
|
|
|
|
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
|
|
|
|
|
static void SMM_GEMM(...) {
|
|
|
|
|
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
static void GEMM_BATCH(...) {
|
|
|
|
|
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
|
|
|
|
@ -153,11 +168,28 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|
CBLAS_TRANSPOSE transB, int M,
|
|
|
|
|
int N, int K, T alpha, const T *A,
|
|
|
|
|
const T *B, T beta, T *C) const {
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans &&
|
|
|
|
|
transB == CblasNoTrans) {
|
|
|
|
|
// refer to https://github.com/hfp/libxsmm/blob/master/README.md
|
|
|
|
|
// Note: SMM use ColMajor
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|
const char transb = 'N';
|
|
|
|
|
const int lda = M;
|
|
|
|
|
const int ldb = K;
|
|
|
|
|
const int ldc = M;
|
|
|
|
|
CBlas<T>::SMM_GEMM(&transa, &transb, &M, &N, &K, &alpha, A, &lda, B, &ldb,
|
|
|
|
|
&beta, C, &ldc);
|
|
|
|
|
} else {
|
|
|
|
|
#endif
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
|
|
|
|
|
ldb, beta, C, ldc);
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|