|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
@ -161,6 +162,25 @@ struct CBlas<platform::float16> {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
};
|
|
|
|
|
// Decides whether a GEMM call with the given shape and scalars should be
// dispatched to LIBXSMM instead of the regular BLAS backend.
//
// Returns true only when all of the following hold:
//   * the build defines PADDLE_WITH_LIBXSMM,
//   * m * n * k does not exceed the custom threshold of 20 * 20 * 20,
//   * neither operand is transposed (transa and transb are both false),
//   * alpha equals 1 and beta equals 0, each within
//     std::numeric_limits<T>::epsilon().
// In every other case, including builds without LIBXSMM, it returns false.
template <typename T>
inline static bool UseXSMM(const int &m, const int &n, const int &k,
                           bool transa, bool transb, const T &alpha,
                           const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
  // But the threshold is custom
  constexpr long long LIBXSMM_THRESHOLD = 20 * 20 * 20;

  // Widen before multiplying: the product of three int dimensions can
  // overflow int for large matrices, which is undefined behaviour and could
  // wrap to a value below the threshold.
  const long long volume = static_cast<long long>(m) * n * k;
  if (volume > LIBXSMM_THRESHOLD || transa || transb) {
    return false;
  }

  const T eps = std::numeric_limits<T>::epsilon();
  // Compare |alpha - 1| against epsilon. The absolute value must be taken of
  // the difference itself, not of the comparison result, otherwise any
  // alpha < 1 would be treated as "equal to 1".
  const bool alpha_is_one = !(std::abs<T>(alpha - static_cast<T>(1)) > eps);
  const bool beta_is_zero = !(std::abs<T>(beta) > eps);
  return alpha_is_one && beta_is_zero;
#else
  return false;
#endif
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
@ -172,8 +192,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
|
#ifdef PADDLE_WITH_LIBXSMM
|
|
|
|
|
if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans &&
|
|
|
|
|
transB == CblasNoTrans) {
|
|
|
|
|
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
|
|
|
|
|
beta)) {
|
|
|
|
|
// refer to https://github.com/hfp/libxsmm/blob/master/README.md
|
|
|
|
|
// Note: SMM use ColMajor
|
|
|
|
|
const char transa = 'N';
|
|
|
|
|