|
|
|
@ -75,25 +75,24 @@ namespace jit = platform::jit;
|
|
|
|
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
|
|
|
|
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
|
|
|
|
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
|
|
|
|
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
|
|
|
|
|
|
|
|
|
|
|
|
// do not include lt8, eq8, eq16
|
|
|
|
#define FOR_EACH_ISA(macro_, block) \
|
|
|
|
#define FOR_EACH_COMMON_BLOCK(macro_, isa) \
|
|
|
|
macro_(jit::avx512f, block); \
|
|
|
|
macro_(isa, kGT8LT16) macro_(isa, kGT16)
|
|
|
|
macro_(jit::avx2, block); \
|
|
|
|
|
|
|
|
macro_(jit::avx, block); \
|
|
|
|
#define FOR_EACH_ISA_COMMON_BLOCK(macro_) \
|
|
|
|
macro_(jit::isa_any, block)
|
|
|
|
FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \
|
|
|
|
|
|
|
|
FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \
|
|
|
|
#define FOR_EACH_BLOCK(macro_, isa) \
|
|
|
|
FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \
|
|
|
|
macro_(isa, kLT8); \
|
|
|
|
FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any)
|
|
|
|
macro_(isa, kEQ8); \
|
|
|
|
|
|
|
|
macro_(isa, kGT8LT16); \
|
|
|
|
#define FOR_EACH_ALL_BLOCK(macro_, isa) \
|
|
|
|
macro_(isa, kEQ16); \
|
|
|
|
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \
|
|
|
|
macro_(isa, kGT16)
|
|
|
|
macro_(isa, kGT16)
|
|
|
|
|
|
|
|
|
|
|
|
#define FOR_EACH_ISA_BLOCK(macro_) \
|
|
|
|
#define FOR_EACH_ISA_ALL_BLOCK(macro_) \
|
|
|
|
FOR_EACH_BLOCK(macro_, jit::avx512f); \
|
|
|
|
FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \
|
|
|
|
FOR_EACH_BLOCK(macro_, jit::avx2); \
|
|
|
|
FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \
|
|
|
|
FOR_EACH_BLOCK(macro_, jit::avx); \
|
|
|
|
FOR_EACH_ALL_BLOCK(macro_, jit::avx) \
|
|
|
|
FOR_EACH_BLOCK(macro_, jit::isa_any)
|
|
|
|
FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* VMUL JitKernel */
|
|
|
|
/* VMUL JitKernel */
|
|
|
|
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
|
|
|
|
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
|
|
|
|
@ -121,8 +120,8 @@ class VMulKernelImpl : public VMulKernel<T> {
|
|
|
|
platform::dynload::vdMul(n, x, y, z); \
|
|
|
|
platform::dynload::vdMul(n, x, y, z); \
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT);
|
|
|
|
FOR_EACH_ISA(VMUL_MKL_FLOAT, kGT16);
|
|
|
|
FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE);
|
|
|
|
FOR_EACH_ISA_BLOCK(VMUL_MKL_DOUBLE);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
#define VMUL_INTRI8_FLOAT(isa) \
|
|
|
|
#define VMUL_INTRI8_FLOAT(isa) \
|
|
|
|
@ -178,8 +177,8 @@ class VAddKernelImpl : public VAddKernel<T> {
|
|
|
|
platform::dynload::vdAdd(n, x, y, z); \
|
|
|
|
platform::dynload::vdAdd(n, x, y, z); \
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT);
|
|
|
|
FOR_EACH_ISA(VADD_MKL_FLOAT, kGT16);
|
|
|
|
FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE);
|
|
|
|
FOR_EACH_ISA_BLOCK(VADD_MKL_DOUBLE);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
#define VADD_INTRI8_FLOAT(isa) \
|
|
|
|
#define VADD_INTRI8_FLOAT(isa) \
|
|
|
|
@ -210,10 +209,9 @@ VADD_INTRI8_FLOAT(jit::avx512f);
|
|
|
|
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
|
|
|
|
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
|
|
|
|
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
|
|
|
|
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
|
|
|
|
|
|
|
|
|
|
|
|
#undef FOR_EACH_ISA_ALL_BLOCK
|
|
|
|
#undef FOR_EACH_ISA
|
|
|
|
#undef FOR_EACH_ALL_BLOCK
|
|
|
|
#undef FOR_EACH_BLOCK
|
|
|
|
#undef FOR_EACH_ISA_COMMON_BLOCK
|
|
|
|
#undef FOR_EACH_ISA_BLOCK
|
|
|
|
#undef FOR_EACH_COMMON_BLOCK
|
|
|
|
|
|
|
|
#undef REGISTER_BLAS_JITKERNEL
|
|
|
|
#undef REGISTER_BLAS_JITKERNEL
|
|
|
|
#undef DEFINE_WITH_DTYPE
|
|
|
|
#undef DEFINE_WITH_DTYPE
|
|
|
|
#undef SEARCH_ISA_BLOCK
|
|
|
|
#undef SEARCH_ISA_BLOCK
|
|
|
|
|