|
|
|
@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reference (scalar-loop) implementation of vector scaling: y[i] = a[0] * x[i]
// for i in [0, n). `a` points to a single scalar; safe when x and y alias,
// since each element is read before its destination is written.
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
  const T scale = a[0];
  int i = 0;
  while (i < n) {
    y[i] = scale * x[i];
    ++i;
  }
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
template <typename T>
|
|
|
|
|
void VMulMKL(const T* x, const T* y, T* z, int n);
|
|
|
|
@ -83,6 +90,28 @@ template <>
|
|
|
|
|
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
|
|
|
|
|
// MKL VML path: element-wise double-precision add, z = x + y over n elements.
platform::dynload::vdAdd(n, x, y, z);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void VScalMKL(const T* a, const T* x, T* y, int n);
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
|
|
|
|
|
// cblas_sscal scales only in place, so it is usable just when x aliases y.
if (x == y) {
|
|
|
|
|
platform::dynload::cblas_sscal(n, *a, y, 1);
|
|
|
|
|
} else {
|
|
|
|
|
// Out-of-place scale: fall back to the reference loop.
VScalRefer<float>(a, x, y, n);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
|
|
|
|
|
// cblas_dscal scales only in place, so it is usable just when x aliases y.
if (x == y) {
|
|
|
|
|
platform::dynload::cblas_dscal(n, *a, y, 1);
|
|
|
|
|
} else {
|
|
|
|
|
// Out-of-place scale: fall back to the reference loop.
VScalRefer<double>(a, x, y, n);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#define DECLARE_STATIC_FUNC \
|
|
|
|
@ -226,87 +255,60 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#undef DECLARE_STATIC_FUNC
|
|
|
|
|
|
|
|
|
|
REGISTER_JITKERNEL(vmul, VMulKernel);
|
|
|
|
|
REGISTER_JITKERNEL(vadd, VAddKernel);
|
|
|
|
|
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
|
|
|
|
|
|
|
|
|
|
/* VSCAL JitKernel */
|
|
|
|
|
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
|
|
|
|
|
/* VScal JitKernel */
|
|
|
|
|
template <typename T>
|
|
|
|
|
class VScalKernelImpl : public VScalKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
|
|
|
|
|
void Compute(const T a, const T* x, T* y) const override {
|
|
|
|
|
for (int i = 0; i < this->num_; ++i) {
|
|
|
|
|
y[i] = a * x[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void Compute(const T a, T* x) const override {
|
|
|
|
|
for (int i = 0; i < this->num_; ++i) {
|
|
|
|
|
x[i] = a * x[i];
|
|
|
|
|
DECLARE_STATIC_FUNC;
|
|
|
|
|
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
if (useJIT(d)) {
|
|
|
|
|
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
|
|
|
|
|
jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096));
|
|
|
|
|
this->Compute =
|
|
|
|
|
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
// Per-(isa, block) specialization: in-place float scale (x *= a) via MKL
// cblas_sscal with unit stride over this->num_ elements.
#define MKL_FLOAT(isa, block) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
|
|
|
|
|
const { \
|
|
|
|
|
platform::dynload::cblas_sscal(this->num_, a, x, 1); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Per-(isa, block) specialization: in-place double scale (x *= a) via MKL
// cblas_dscal with unit stride over this->num_ elements.
#define MKL_DOUBLE(isa, block) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
|
|
|
|
|
const { \
|
|
|
|
|
platform::dynload::cblas_dscal(this->num_, a, x, 1); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FOR_EACH_ISA(MKL_FLOAT, kGT16);
|
|
|
|
|
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
|
|
|
|
|
if (useMKL(d)) {
|
|
|
|
|
this->Compute = VScalMKL<T>;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// AVX specialization for exactly 8 floats (kEQ8): y = a * x with one
// unaligned 256-bit load, a broadcast of the scalar, one multiply, one store.
#define INTRI8_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
|
|
|
|
|
const float a, const float* x, float* y) const { \
|
|
|
|
|
__m256 tmp; \
|
|
|
|
|
__m256 scalar = _mm256_set1_ps(a); \
|
|
|
|
|
tmp = _mm256_loadu_ps(x); \
|
|
|
|
|
tmp = _mm256_mul_ps(tmp, scalar); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp); \
|
|
|
|
|
}
|
|
|
|
|
#define INTRI8_INPLACE_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
|
|
|
|
|
const { \
|
|
|
|
|
__m256 tmp; \
|
|
|
|
|
__m256 scalar = _mm256_set1_ps(a); \
|
|
|
|
|
tmp = _mm256_loadu_ps(x); \
|
|
|
|
|
tmp = _mm256_mul_ps(tmp, scalar); \
|
|
|
|
|
_mm256_storeu_ps(x, tmp); \
|
|
|
|
|
this->Compute = VScalRefer<T>;
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
INTRI8_FLOAT(jit::avx);
|
|
|
|
|
INTRI8_INPLACE_FLOAT(jit::avx);
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<gen::VScalJitCode> jitcode_{nullptr};
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX2__
|
|
|
|
|
INTRI8_FLOAT(jit::avx2);
|
|
|
|
|
INTRI8_INPLACE_FLOAT(jit::avx2);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
template <>
|
|
|
|
|
// Use the xbyak JIT kernel whenever VScalJitCode can generate code for size d.
bool VScalKernelImpl<float>::useJIT(int d) {
|
|
|
|
|
return gen::VScalJitCode::init(d);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX512F__
|
|
|
|
|
INTRI8_FLOAT(jit::avx512f);
|
|
|
|
|
INTRI8_INPLACE_FLOAT(jit::avx512f);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
template <>
|
|
|
|
|
// Heuristic: route float VScal through MKL only for large vectors (d > 512);
// smaller sizes use the JIT/reference path instead.
bool VScalKernelImpl<float>::useMKL(int d) {
|
|
|
|
|
return d > 512;
|
|
|
|
|
}
|
|
|
|
|
template <>
|
|
|
|
|
// Double VScal always takes the MKL path, regardless of size.
bool VScalKernelImpl<double>::useMKL(int d) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
// TODO(TJ): eq16 test and complete avx512
|
|
|
|
|
|
|
|
|
|
#undef INTRI8_FLOAT
|
|
|
|
|
#undef INTRI8_INPLACE_FLOAT
|
|
|
|
|
#undef MKL_FLOAT
|
|
|
|
|
#undef MKL_DOUBLE
|
|
|
|
|
#undef DECLARE_STATIC_FUNC
|
|
|
|
|
|
|
|
|
|
REGISTER_JITKERNEL(vmul, VMulKernel);
|
|
|
|
|
REGISTER_JITKERNEL(vadd, VAddKernel);
|
|
|
|
|
REGISTER_JITKERNEL(vscal, VScalKernel);
|
|
|
|
|
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
|
|
|
|
|
|
|
|
|
|
/* VAddBias JitKernel */
|
|
|
|
|
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
|
|
|
|
@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
|
|
|
|
|
void Compute(const T* x, T* y) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
|
|
|
|
|
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
|
|
|
|
|
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
|
|
|
|
|
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
|
|
|
|
|