|
|
|
@ -217,64 +217,6 @@ struct CBlas<platform::float16> {
|
|
|
|
|
#endif
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
 * Decides whether a GEMM call should be dispatched to libxsmm.
 *
 * libxsmm is only selected when all of the following hold:
 *   - the build defines PADDLE_WITH_LIBXSMM,
 *   - the problem volume m * n * k does not exceed the custom threshold,
 *   - neither operand is transposed,
 *   - alpha is 1 and beta is 0, each within machine epsilon of T.
 *
 * @param m, n, k  GEMM problem dimensions.
 * @param transa   true when A is transposed.
 * @param transb   true when B is transposed.
 * @param alpha    scaling factor applied to A * B.
 * @param beta     scaling factor applied to the existing C.
 * @return true when the libxsmm path should be taken, false otherwise
 *         (always false when PADDLE_WITH_LIBXSMM is not defined).
 */
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
                    bool transb, const T &alpha, const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
  // But the threshold is custom
  constexpr long long LIBXSMM_THRESHOLD = 20 * 20 * 20;

  // Widen before multiplying: m * n * k overflows int for large matrices,
  // and a wrapped product could slip under the threshold.
  const long long volume = static_cast<long long>(m) * n * k;
  if (volume > LIBXSMM_THRESHOLD || transa || transb) {
    return false;
  }

  // The magnitude must be taken first and compared afterwards; taking the
  // magnitude of the comparison result accepts every alpha below 1.
  const T eps = std::numeric_limits<T>::epsilon();
  if (std::abs(alpha - static_cast<T>(1)) > eps || std::abs(beta) > eps) {
    return false;
  }
  return true;
#else
  return false;
#endif
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
|
|
|
|
|
bool transa, bool transb,
|
|
|
|
|
const platform::float16 &alpha,
|
|
|
|
|
const platform::float16 &beta) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
|
|
|
|
|
CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
|
|
|
|
|
const T *A, int lda, const T *B, int ldb, T beta, T *C,
|
|
|
|
|
int ldc) {
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
|
|
|
|
|
beta)) {
|
|
|
|
|
// Note: SMM use ColMajor
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|
const char transb = 'N';
|
|
|
|
|
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
|
|
|
|
|
&beta, C, &ldc);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_MKL_SPLIT_GEMM
|
|
|
|
|
constexpr int bs = 2;
|
|
|
|
|
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
|
|
|
|
|
for (int off = 0; off < M; off += bs) {
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
|
|
|
|
|
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
@ -319,8 +261,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
|
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -329,9 +271,20 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
|
|
|
|
|
int N, int K, T alpha, const T *A,
|
|
|
|
|
int lda, const T *B, int ldb,
|
|
|
|
|
T beta, T *C, int ldc) const {
|
|
|
|
|
GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
|
|
|
|
|
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
|
|
|
|
|
lda, B, ldb, beta, C, ldc);
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
|
|
|
|
|
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
|
|
|
|
|
lda, B, ldb, beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|
CBLAS_TRANSPOSE transB, int M,
|
|
|
|
|
int N, int K, T alpha, const T *A,
|
|
|
|
|
int lda, const T *B, int ldb,
|
|
|
|
|
T beta, T *C, int ldc) const {
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext>
|
|
|
|
@ -440,6 +393,43 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
|
|
|
|
|
const T *A, const T *B, T *C) const {
|
|
|
|
|
this->template GEMM<T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
|
|
|
|
|
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
|
|
|
|
|
N);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
|
|
|
|
|
const int K, const T *A,
|
|
|
|
|
const T *B, T *C) const {
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
|
|
|
|
|
// But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
|
|
|
|
|
|
|
|
|
|
// Since the matrix is very small,
|
|
|
|
|
// so the unit of calculation is already very fast,
|
|
|
|
|
// and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
|
|
|
|
|
// use xsmm directly.
|
|
|
|
|
// Note: SMM use ColMajor
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|
const char transb = 'N';
|
|
|
|
|
const T alpha = static_cast<T>(1);
|
|
|
|
|
const T beta = static_cast<T>(0);
|
|
|
|
|
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
|
|
|
|
|
C, &N);
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
|
|
|
|
|
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
|
|
|
|
|