|
|
|
@ -105,7 +105,7 @@ struct CBlas<float> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VSQR(ARGS... args) {
|
|
|
|
|
static void VSQUARE(ARGS... args) {
|
|
|
|
|
platform::dynload::vsSqr(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -195,7 +195,7 @@ struct CBlas<double> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VSQR(ARGS... args) {
|
|
|
|
|
static void VSQUARE(ARGS... args) {
|
|
|
|
|
platform::dynload::vdSqr(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -262,7 +262,9 @@ struct CBlas<platform::float16> {
|
|
|
|
|
}
|
|
|
|
|
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
|
|
|
|
|
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
|
|
|
|
|
static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
|
|
|
|
|
static void VSQUARE(...) {
|
|
|
|
|
PADDLE_THROW("float16 VSQUARE not supported on CPU");
|
|
|
|
|
}
|
|
|
|
|
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
|
|
|
|
|
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
|
|
|
|
|
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
|
|
|
|
@ -423,12 +425,12 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
CBlas<T>::VSQR(n, x, y);
|
|
|
|
|
CBlas<T>::VSQUARE(n, x, y);
|
|
|
|
|
#else
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = std::sqrt(x[i]);
|
|
|
|
|
y[i] = x[i] * x[i];
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|