|
|
|
@ -84,6 +84,11 @@ struct CBlas<float> {
|
|
|
|
|
platform::dynload::cblas_sscal(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static float ASUM(ARGS... args) {
|
|
|
|
|
return platform::dynload::cblas_sasum(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_BATCH(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_sgemm_batch(args...);
|
|
|
|
@ -174,6 +179,11 @@ struct CBlas<double> {
|
|
|
|
|
platform::dynload::cblas_dscal(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static double ASUM(ARGS... args) {
|
|
|
|
|
return platform::dynload::cblas_dasum(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM_BATCH(ARGS... args) {
|
|
|
|
|
platform::dynload::cblas_dgemm_batch(args...);
|
|
|
|
@ -268,6 +278,7 @@ struct CBlas<platform::float16> {
|
|
|
|
|
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
|
|
|
|
|
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
|
|
|
|
|
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
|
|
|
|
|
static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
static void GEMM_BATCH(...) {
|
|
|
|
|
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
|
|
|
|
@ -476,6 +487,23 @@ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
T Blas<platform::CPUDeviceContext>::ASUM(int n, T *x, int inc) const {
|
|
|
|
|
auto sum = static_cast<T>(0.0);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
sum = Blas<T>::ASUM(n, x, inc);
|
|
|
|
|
#else
|
|
|
|
|
//TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum
|
|
|
|
|
for (int c = 0; c < n; ++c) {
|
|
|
|
|
sum += x[c];
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
return sum;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
|
|
|
|
|