|
|
|
@ -162,10 +162,10 @@ struct CBlas<platform::float16> {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline static bool UseXSMM(const int &m, const int &n, const int &k,
|
|
|
|
|
bool transa, bool transb, const T &alpha,
|
|
|
|
|
const T &beta) {
|
|
|
|
|
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
|
|
|
|
|
bool transb, const T &alpha, const T &beta) {
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
|
|
|
|
|
// But the threshold is custom
|
|
|
|
@ -182,6 +182,14 @@ inline static bool UseXSMM(const int &m, const int &n, const int &k,
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
|
|
|
|
|
bool transa, bool transb,
|
|
|
|
|
const platform::float16 &alpha,
|
|
|
|
|
const platform::float16 &beta) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
@ -194,7 +202,6 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
|
|
|
|
|
beta)) {
|
|
|
|
|
// refer to https://github.com/hfp/libxsmm/blob/master/README.md
|
|
|
|
|
// Note: SMM use ColMajor
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|
const char transb = 'N';
|
|
|
|
|