|
|
|
@ -78,6 +78,11 @@ struct CBlas<float> {
|
|
|
|
|
return platform::dynload::cblas_sdot(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void SCAL(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sscal(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_BATCH(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sgemm_batch(args...);
|
|
|
|
@ -148,6 +153,11 @@ struct CBlas<double> {
|
|
|
|
|
return platform::dynload::cblas_ddot(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void SCAL(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dscal(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_BATCH(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dgemm_batch(args...);
|
|
|
|
@ -221,6 +231,7 @@ struct CBlas<platform::float16> {
|
|
|
|
|
}
|
|
|
|
|
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
|
|
|
|
|
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
|
|
|
|
|
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
static void GEMM_BATCH(...) {
|
|
|
|
|
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
|
|
|
|
@ -367,7 +378,7 @@ template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
return CBlas<T>::DOT(n, x, y);
|
|
|
|
|
return CBlas<T>::DOT(n, x, 1, y, 1);
|
|
|
|
|
#else
|
|
|
|
|
// try to find if openblas support cblas_dot
|
|
|
|
|
T sum = 0;
|
|
|
|
@ -378,6 +389,20 @@ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a,
|
|
|
|
|
const T *x) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
CBlas<T>::SCAL(n, a, x, 1);
|
|
|
|
|
#else
|
|
|
|
|
// try to find if openblas support cblas_scal
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
x[i] = a * x[i];
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
|
|
|
|
|