|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
@ -102,6 +103,16 @@ struct CBlas<float> {
|
|
|
|
|
static void VEXP(ARGS... args) {
|
|
|
|
|
platform::dynload::vsExp(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VSQR(ARGS... args) {
|
|
|
|
|
platform::dynload::vsSqr(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VPOW(ARGS... args) {
|
|
|
|
|
platform::dynload::vsPowx(args...);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -182,6 +193,16 @@ struct CBlas<double> {
|
|
|
|
|
static void VEXP(ARGS... args) {
|
|
|
|
|
platform::dynload::vdExp(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VSQR(ARGS... args) {
|
|
|
|
|
platform::dynload::vdSqr(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void VPOW(ARGS... args) {
|
|
|
|
|
platform::dynload::vdPowx(args...);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
@ -241,6 +262,8 @@ struct CBlas<platform::float16> {
|
|
|
|
|
}
|
|
|
|
|
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
|
|
|
|
|
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
|
|
|
|
|
static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
|
|
|
|
|
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
|
|
|
|
|
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
|
|
|
|
|
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
CBlas<T>::VSQR(n, x, y);
|
|
|
|
|
#else
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = std::sqrt(x[i]);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
|
|
|
|
|
T *y) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
CBlas<T>::VPOW(n, x, a, y);
|
|
|
|
|
#else
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = std::pow(x[i], a);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
template <typename T>
|
|
|
|
|
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
|
|
|
|
|