|
|
|
@ -73,6 +73,11 @@ struct CBlas<float> {
|
|
|
|
|
platform::dynload::cblas_sgemv(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static float DOT(ARGS... args) {
|
|
|
|
|
return platform::dynload::cblas_sdot(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_BATCH(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sgemm_batch(args...);
|
|
|
|
@ -138,6 +143,11 @@ struct CBlas<double> {
|
|
|
|
|
platform::dynload::cblas_dgemv(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static double DOT(ARGS... args) {
|
|
|
|
|
return platform::dynload::cblas_ddot(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_BATCH(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dgemm_batch(args...);
|
|
|
|
@ -210,6 +220,7 @@ struct CBlas<platform::float16> {
|
|
|
|
|
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
|
|
|
|
|
}
|
|
|
|
|
// No float16 VMUL exists for CPU: whatever the arguments, this only invokes
// PADDLE_THROW with an explanatory message.
static void VMUL(...) {
  PADDLE_THROW("float16 VMUL not supported on CPU");
}
|
|
|
|
|
// No float16 DOT exists for CPU: whatever the arguments, this only invokes
// PADDLE_THROW with an explanatory message.
// (The redundant `;` that followed the function body has been dropped so the
// definition matches its sibling stubs and stays clean under -Wextra-semi.)
static void DOT(...) {
  PADDLE_THROW("float16 DOT not supported on CPU");
}
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
static void GEMM_BATCH(...) {
|
|
|
|
|
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
|
|
|
|
@ -352,6 +363,21 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
return CBlas<T>::DOT(n, x, y);
|
|
|
|
|
#else
|
|
|
|
|
// try to find if openblas support cblas_dot
|
|
|
|
|
T sum = 0;
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
sum += x[i] * y[i];
|
|
|
|
|
}
|
|
|
|
|
return sum;
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
|
|
|
|
@ -423,7 +449,6 @@ void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
|
|
|
|
|
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
|
|
|
|
|
C, &N);
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
|
|
|
|
|