diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 701965759e..238bd3f8de 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -162,10 +162,10 @@ struct CBlas { } #endif }; + template -inline static bool UseXSMM(const int &m, const int &n, const int &k, - bool transa, bool transb, const T &alpha, - const T &beta) { +inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa, + bool transb, const T &alpha, const T &beta) { #ifdef PADDLE_WITH_LIBXSMM // Refer to https://github.com/hfp/libxsmm/blob/master/README.md // But the threshold is custom @@ -182,6 +182,14 @@ inline static bool UseXSMM(const int &m, const int &n, const int &k, return false; } +template <> +inline bool UseXSMM(const int &m, const int &n, const int &k, + bool transa, bool transb, + const platform::float16 &alpha, + const platform::float16 &beta) { + return false; +} + template <> template void Blas::GEMM(CBLAS_TRANSPOSE transA, @@ -194,7 +202,6 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, #ifdef PADDLE_WITH_LIBXSMM if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, beta)) { - // refer to https://github.com/hfp/libxsmm/blob/master/README.md // Note: SMM use ColMajor const char transa = 'N'; const char transb = 'N';