@@ -36,35 +36,38 @@ KernelPool& KernelPool::Instance() {
   static KernelPool g_jit_kernels;
   return g_jit_kernels;
 }
 
-#define SEARCH_BLOCK(src, t, isa)       \
-  if (d < AVX_FLOAT_BLOCK) {            \
-    Compute = src<t, isa, kLT8>;        \
-  } else if (d == AVX_FLOAT_BLOCK) {    \
-    Compute = src<t, isa, kEQ8>;        \
-  } else if (d == AVX512_FLOAT_BLOCK) { \
-    Compute = src<t, isa, kEQ16>;       \
-  } else {                              \
-    Compute = src<t, isa, kGT16>;       \
+#define SEARCH_BLOCK(src, t, isa)                             \
+  if (d < AVX_FLOAT_BLOCK) {                                  \
+    Compute = src<t, isa, kLT8>;                              \
+  } else if (d == AVX_FLOAT_BLOCK) {                          \
+    Compute = src<t, isa, kEQ8>;                              \
+  } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
+    Compute = src<t, isa, kGT8LT16>;                          \
+  } else if (d == AVX512_FLOAT_BLOCK) {                       \
+    Compute = src<t, isa, kEQ16>;                             \
+  } else {                                                    \
+    Compute = src<t, isa, kGT16>;                             \
   }
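(Editor's sketch, not part of the diff: AVX_FLOAT_BLOCK is 8 floats per 256-bit register and AVX512_FLOAT_BLOCK is 16. Assuming the macro expands in a scope that declares an int d and a function pointer Compute, a call like SEARCH_BLOCK(VMulCompute, float, jit::avx) with d == 12 now resolves to the new branch:)

    Compute = VMulCompute<float, jit::avx, kGT8LT16>;
    // before this change, d == 12 fell through to the kGT16 catch-all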
 
-#define SEARCH_ISA_BLOCK(src, t)              \
-  if (jit::MayIUse(jit::avx512_common)) {     \
-    SEARCH_BLOCK(src, t, jit::avx512_common); \
-  } else if (jit::MayIUse(jit::avx2)) {       \
-    SEARCH_BLOCK(src, t, jit::avx2);          \
-  } else if (jit::MayIUse(jit::avx)) {        \
-    SEARCH_BLOCK(src, t, jit::avx);           \
-  } else {                                    \
-    SEARCH_BLOCK(src, t, jit::isa_any);       \
+#define SEARCH_ISA_BLOCK(src, t)        \
+  if (jit::MayIUse(jit::avx512f)) {     \
+    SEARCH_BLOCK(src, t, jit::avx512f); \
+  } else if (jit::MayIUse(jit::avx2)) { \
+    SEARCH_BLOCK(src, t, jit::avx2);    \
+  } else if (jit::MayIUse(jit::avx)) {  \
+    SEARCH_BLOCK(src, t, jit::avx);     \
+  } else {                              \
+    SEARCH_BLOCK(src, t, jit::isa_any); \
   }
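(Editor's sketch: stacked on SEARCH_BLOCK, this gives a two-level runtime dispatch, ISA first and block size second. On a CPU where jit::MayIUse(jit::avx2) holds but avx512f does not, SEARCH_ISA_BLOCK(VMulCompute, float) with d == 12 resolves to:)

    Compute = VMulCompute<float, jit::avx2, kGT8LT16>;
    // with no SIMD support at all, the jit::isa_any branch keeps a portable fallback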
 
-#define FOR_EACH_BLOCK(macro_, isa) \
-  macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kEQ16) macro_(isa, kGT16)
+// do not include lt8, eq8, eq16
+#define FOR_EACH_COMMON_BLOCK(macro_, isa) \
+  macro_(isa, kGT8LT16) macro_(isa, kGT16)
 
-#define FOR_EACH_ISA_BLOCK(macro_)           \
-  FOR_EACH_BLOCK(macro_, jit::avx512_common) \
-  FOR_EACH_BLOCK(macro_, jit::avx2)          \
-  FOR_EACH_BLOCK(macro_, jit::avx)           \
-  FOR_EACH_BLOCK(macro_, jit::isa_any)
+#define FOR_EACH_ISA_COMMON_BLOCK(macro_)     \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::avx2)    \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::avx)     \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any)
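(Editor's sketch of the expansion, with M as a placeholder macro argument: FOR_EACH_ISA_COMMON_BLOCK(M) stamps out one invocation per (isa, block) pair, four ISAs times the two "common" blocks:)

    M(jit::avx512f, kGT8LT16) M(jit::avx512f, kGT16)
    M(jit::avx2, kGT8LT16)    M(jit::avx2, kGT16)
    M(jit::avx, kGT8LT16)     M(jit::avx, kGT16)
    M(jit::isa_any, kGT8LT16) M(jit::isa_any, kGT16)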
@@ -78,24 +81,56 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) {
 }
 
 #ifdef PADDLE_USE_MKLML
-#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block)                             \
-  template <>                                                             \
-  static void VMulCompute<float, isa, block>(const int n, const float* x, \
-                                             const float* y, float* z) {  \
-    platform::dynload::vsMul(n, x, y, z);                                 \
+#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block)                      \
+  template <>                                                      \
+  void VMulCompute<float, isa, block>(const int n, const float* x, \
+                                      const float* y, float* z) {  \
+    platform::dynload::vsMul(n, x, y, z);                          \
   }
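(Why the static had to go: C++ forbids storage-class specifiers on explicit specializations, so the old form is ill-formed. A minimal reproduction, editor's illustration:)

    template <typename T> static void F(T) {}
    template <> static void F<int>(int) {}  // error: explicit specialization
                                            // cannot have a storage class
    template <> void F<int>(int) {}         // OK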
 
-#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block)                              \
-  template <>                                                               \
-  static void VMulCompute<double, isa, block>(const int n, const double* x, \
-                                              const double* y, double* z) { \
-    platform::dynload::vdMul(n, x, y, z);                                   \
+#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block)                       \
+  template <>                                                        \
+  void VMulCompute<double, isa, block>(const int n, const double* x, \
+                                       const double* y, double* z) { \
+    platform::dynload::vdMul(n, x, y, z);                            \
   }
 
-FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT)
-FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE)
+// TODO(TJ): add EQ8
+FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT)
+FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE)
+DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kLT8)
+DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ16)
 #endif
 
+// mkl > avx > for, ">" means better
+#ifdef PADDLE_USE_MKLML
+DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ8)
+#elif defined __AVX__
+template <>
+void VMulCompute<float, jit::avx, kEQ8>(const int n, const float* x,
+                                        const float* y, float* z) {
+  __m256 tmpx, tmpy;
+  tmpx = _mm256_loadu_ps(x);
+  tmpy = _mm256_loadu_ps(y);
+  tmpx = _mm256_mul_ps(tmpx, tmpy);
+  _mm256_storeu_ps(z, tmpx);
+}
+#endif
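(Editor's sketch, not in this diff: without MKL, kEQ16 still falls through to the scalar loop; the same intrinsic pattern would extend to it with two 256-bit lanes, assuming that fallback were wanted:)

    template <>
    void VMulCompute<float, jit::avx, kEQ16>(const int n, const float* x,
                                             const float* y, float* z) {
      // two unaligned 8-float multiplies cover the 16-element block
      _mm256_storeu_ps(z, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)));
      _mm256_storeu_ps(
          z + 8, _mm256_mul_ps(_mm256_loadu_ps(x + 8), _mm256_loadu_ps(y + 8)));
    }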
 
+// avx2 > mkl > for
+#ifdef __AVX2__
+template <>
+void VMulCompute<float, jit::avx2, kEQ8>(const int n, const float* x,
+                                         const float* y, float* z) {
+  __m256 tmpx, tmpy;
+  tmpx = _mm256_loadu_ps(x);
+  tmpy = _mm256_loadu_ps(y);
+  tmpx = _mm256_mul_ps(tmpx, tmpy);
+  _mm256_storeu_ps(z, tmpx);
+}
+#elif defined PADDLE_USE_MKLML
+DEFINE_VMUL_COMPUTE_FLOAT(jit::avx2, kEQ8)
+#endif
+// TODO(TJ): test and complete avx512
 
 #undef DEFINE_VMUL_COMPUTE_FLOAT
 #undef DEFINE_VMUL_COMPUTE_DOUBLE
@@ -142,8 +177,8 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
     : Kernel(), d_(d) {
   d2_ = d * 2;
   d3_ = d * 3;
-  if (platform::jit::MayIUse(platform::jit::avx512_common)) {
-    math::VecActivations<float, platform::jit::avx512_common> act_functor;
+  if (platform::jit::MayIUse(platform::jit::avx512f)) {
+    math::VecActivations<float, platform::jit::avx512f> act_functor;
     act_gate_ = act_functor(act_gate_str);
     act_cell_ = act_functor(act_cell_str);
     act_cand_ = act_functor(act_cand_str);
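(Editor's sketch: VecActivations acts as a string-to-kernel factory, so each activation name is resolved once at kernel construction; the callable's exact signature below is an assumption, not shown in this diff:)

    math::VecActivations<float, platform::jit::avx512f> act_functor;
    auto act_gate = act_functor("sigmoid");  // hypothetical activation name
    act_gate(d, gates_in, gates_out);        // assumed (size, src, dst) signature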