|
|
|
|
@ -74,15 +74,22 @@ namespace jit = platform::jit;
|
|
|
|
|
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
|
|
|
|
|
FOR_EACH_ALL_BLOCK(macro_, jit::any)
|
|
|
|
|
|
|
|
|
|
/* VMUL JitKernel */
|
|
|
|
|
/* Portable element-wise multiply fallback shared by every ISA/block
 * specialization of VMulCompute; expects `n`, `x`, `y`, `z` in scope.
 * FIX: the macro was truncated — the closing `}` of the loop body was
 * missing, which leaves an unbalanced brace at every expansion site. */
#define VMUL_ANY                \
  for (int i = 0; i < n; ++i) { \
    z[i] = x[i] * y[i];         \
  }
/* Emits the explicit constructor specialization `cls_<dtype_>::cls_(int d)`
 * that picks the best ISA/block implementation of `func_` for size `d`
 * via SEARCH_ISA_BLOCK. */
#define BIND_KERNEL_WITH_DTYPE(cls_, func_, dtype_) \
  template <>                                       \
  cls_<dtype_>::cls_(int d) {                       \
    SEARCH_ISA_BLOCK(func_, dtype_);                \
  }
/* Binds one kernel class to its compute function for both supported
 * element types (float and double). */
#define BIND_KERNEL(cls_, func_)              \
  BIND_KERNEL_WITH_DTYPE(cls_, func_, float); \
  BIND_KERNEL_WITH_DTYPE(cls_, func_, double)
/* VMUL JitKernel */
|
|
|
|
|
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
|
|
|
|
|
static void VMulCompute(const int n, const T* x, const T* y, T* z) {
|
|
|
|
|
VMUL_ANY
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
z[i] = x[i] * y[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_USE_MKLML
|
|
|
|
|
@ -107,6 +114,8 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
|
|
|
|
|
/// lt8
// n < 8: the MKL vector multiply is the preferred float path on every
// AVX-class ISA when MKL is available.
#ifdef PADDLE_USE_MKLML
VMUL_MKL_FLOAT(jit::avx, kLT8)
VMUL_MKL_FLOAT(jit::avx2, kLT8)
VMUL_MKL_FLOAT(jit::avx512f, kLT8)
#endif
|
|
|
|
|
|
|
|
|
|
/// eq8
|
|
|
|
|
@ -143,20 +152,93 @@ VMUL_MKL_FLOAT(jit::avx2, kEQ16)
|
|
|
|
|
VMUL_MKL_FLOAT(jit::avx512f, kEQ16)
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// NOTE(review): VMulKernel<float/double> constructors are bound exactly once
// via BIND_KERNEL(VMulKernel, VMulCompute) near the bottom of this file.
// Expanding USE_VMUL_KERNEL here as well would re-define the same explicit
// constructor specializations and break the build, so the macro now only
// forwards to the generic binder and is not invoked directly anymore.
#define USE_VMUL_KERNEL(T, func) BIND_KERNEL_WITH_DTYPE(VMulKernel, func, T)
|
|
|
|
|
// Scrub the VMUL helper macros so the VADD section below cannot expand them
// by accident.
#undef VMUL_ANY
#undef VMUL_INTRI8_FLOAT
#undef VMUL_MKL_FLOAT
#undef VMUL_MKL_DOUBLE
#undef USE_VMUL_KERNEL
|
|
|
|
|
|
|
|
|
|
/* VADD */
|
|
|
|
|
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
|
|
|
|
|
static void VAddCompute(const int n, const T* x, const T* y, T* z) {
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
z[i] = x[i] + y[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_USE_MKLML
|
|
|
|
|
// Specializes VAddCompute<float, isa, block> to call MKL's single-precision
// vector add (vsAdd) instead of the scalar loop.
#define VADD_MKL_FLOAT(isa, block)                                 \
  template <>                                                      \
  void VAddCompute<float, isa, block>(const int n, const float* x, \
                                      const float* y, float* z) {  \
    platform::dynload::vsAdd(n, x, y, z);                          \
  }
|
|
|
|
|
// Specializes VAddCompute<double, isa, block> to call MKL's double-precision
// vector add (vdAdd).
// FIX: the output parameter was declared `float* z`, which does not match the
// primary template's `T* z` with T = double — an explicit specialization must
// repeat the primary template's signature exactly, so this could never
// compile (and would truncate results if it did). It must be `double* z`.
#define VADD_MKL_DOUBLE(isa, block)                                  \
  template <>                                                        \
  void VAddCompute<double, isa, block>(const int n, const double* x, \
                                       const double* y, double* z) { \
    platform::dynload::vdAdd(n, x, y, z);                            \
  }
|
|
|
|
|
// Instantiate the MKL-backed VADD specializations: float across the common
// blocks, double across all blocks.
FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT)
FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
#endif
|
|
|
|
|
|
|
|
|
|
/// lt8
// n < 8: prefer MKL's vector add for float on every AVX-class ISA.
#ifdef PADDLE_USE_MKLML
VADD_MKL_FLOAT(jit::avx, kLT8)
VADD_MKL_FLOAT(jit::avx2, kLT8)
VADD_MKL_FLOAT(jit::avx512f, kLT8)
#endif
|
|
|
|
|
|
|
|
|
|
/// eq8
// Exactly eight floats: one 256-bit unaligned load per input, a single AVX
// add, and one unaligned store replace the scalar loop (`n` is known to be 8
// and is intentionally unused).
#define VADD_INTRI8_FLOAT(isa)                                    \
  template <>                                                     \
  void VAddCompute<float, isa, kEQ8>(const int n, const float* x, \
                                     const float* y, float* z) {  \
    __m256 vx = _mm256_loadu_ps(x);                               \
    __m256 vy = _mm256_loadu_ps(y);                               \
    _mm256_storeu_ps(z, _mm256_add_ps(vx, vy));                   \
  }
|
|
|
|
|
|
|
|
|
|
// mkl > avx > for, ">" means better
// eq8 float dispatch: on plain AVX prefer MKL when it is available, falling
// back to the AVX intrinsic path; on AVX2 the intrinsic version beats MKL.
#ifdef PADDLE_USE_MKLML
VADD_MKL_FLOAT(jit::avx, kEQ8)
#elif defined __AVX__
VADD_INTRI8_FLOAT(jit::avx)
#endif
// avx2 > mkl > for
#ifdef __AVX2__
VADD_INTRI8_FLOAT(jit::avx2)
#elif defined PADDLE_USE_MKLML
VADD_MKL_FLOAT(jit::avx2, kEQ8)
#endif
// TODO(TJ): test and complete avx512
|
|
|
|
|
|
|
|
|
|
/// eq16
// Exactly sixteen floats: currently routed to MKL on all AVX ISAs; a
// two-register intrinsic path could be added later.
#ifdef PADDLE_USE_MKLML
// TODO(TJ): test and complete me
VADD_MKL_FLOAT(jit::avx, kEQ16)
VADD_MKL_FLOAT(jit::avx2, kEQ16)
VADD_MKL_FLOAT(jit::avx512f, kEQ16)
#endif
|
|
|
|
|
|
|
|
|
|
// Scrub the VADD helper macros.
#undef VADD_INTRI8_FLOAT
#undef VADD_MKL_FLOAT
#undef VADD_MKL_DOUBLE
|
|
|
|
|
|
|
|
|
|
// Emit the constructor specializations that wire each kernel class to its
// compute function, for both float and double.
BIND_KERNEL(VMulKernel, VMulCompute);
BIND_KERNEL(VAddKernel, VAddCompute);
|
|
|
|
|
|
|
|
|
|
// Tear down every file-local helper macro so none of them leak past this
// translation unit.
#undef BIND_KERNEL
#undef BIND_KERNEL_WITH_DTYPE
#undef FOR_EACH_ISA_ALL_BLOCK
#undef FOR_EACH_ALL_BLOCK
#undef FOR_EACH_ISA_COMMON_BLOCK
#undef FOR_EACH_COMMON_BLOCK
#undef SEARCH_ISA_BLOCK
#undef SEARCH_BLOCK
|
|
|
|
|
|
|
|
|
|
} // namespace jitkernel
|
|
|
|
|
} // namespace math
|
|
|
|
|
|