|
|
|
|
@ -82,6 +82,11 @@ struct CBlas<float> {
|
|
|
|
|
static void VADD(ARGS... args) {
|
|
|
|
|
platform::dynload::vsAdd(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VMUL(ARGS... args) {
|
|
|
|
|
platform::dynload::vsMul(args...);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
@ -142,6 +147,11 @@ struct CBlas<double> {
|
|
|
|
|
static void VADD(ARGS... args) {
|
|
|
|
|
platform::dynload::vdAdd(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VMUL(ARGS... args) {
|
|
|
|
|
platform::dynload::vdMul(args...);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
|
|
|
|
|
static void SMM_GEMM(...) {
|
|
|
|
|
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
|
|
|
|
|
}
|
|
|
|
|
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
static void GEMM_BATCH(...) {
|
|
|
|
|
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
|
|
|
|
|
@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
|
|
|
|
|
T *z) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
CBlas<T>::VMUL(n, x, y, z);
|
|
|
|
|
#else
|
|
|
|
|
// try to find if openblas support vmul
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
z[i] = x[i] * y[i];
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
|
|
|
|
|
|