|
|
|
@ -235,6 +235,7 @@ INTRI16_FLOAT(jit::avx512f);
|
|
|
|
|
#undef INTRI16_FLOAT
|
|
|
|
|
#undef INTRI_GT8LT16_FLOAT
|
|
|
|
|
#undef INTRI_GT16_FLOAT
|
|
|
|
|
#undef INTRI_VSIGMOID
|
|
|
|
|
|
|
|
|
|
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
|
|
|
|
|
p = std::dynamic_pointer_cast<ker<dtype>>( \
|
|
|
|
@ -243,6 +244,118 @@ INTRI16_FLOAT(jit::avx512f);
|
|
|
|
|
REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
|
|
|
|
|
JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
|
|
|
|
|
|
|
|
|
|
/* VTanh JitKernel */
|
|
|
|
|
template <typename T, jit::cpu_isa_t isa, jit_block>
|
|
|
|
|
class VTanhKernelImpl : public VTanhKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
|
|
|
|
|
vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
|
|
|
|
|
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
|
|
|
|
|
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
|
|
|
|
|
}
|
|
|
|
|
void Compute(const int n, const T* x, T* y) const override {
|
|
|
|
|
vscal_->Compute(n, static_cast<T>(2), x, y);
|
|
|
|
|
vsigmoid_->Compute(n, y, y);
|
|
|
|
|
vscal_->Compute(n, static_cast<T>(2), y);
|
|
|
|
|
vaddbias_->Compute(n, static_cast<T>(-1), y, y);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::shared_ptr<const VScalKernel<T>> vscal_;
|
|
|
|
|
std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
|
|
|
|
|
std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define INTRI_VTANH(tmp) \
|
|
|
|
|
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
|
|
|
|
|
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
|
|
|
|
|
tmp = detail::Exp(tmp); \
|
|
|
|
|
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
|
|
|
|
|
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
|
|
|
|
|
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
|
|
|
|
|
|
|
|
|
|
#define INTRI8_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VTanhKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
|
|
|
|
|
float* y) const { \
|
|
|
|
|
__m256 tmp = _mm256_loadu_ps(x); \
|
|
|
|
|
INTRI_VTANH(tmp); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRI16_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VTanhKernelImpl<float, isa, kEQ16>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
__m256 tmp0 = _mm256_loadu_ps(x); \
|
|
|
|
|
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
|
|
|
|
|
INTRI_VTANH(tmp0); \
|
|
|
|
|
INTRI_VTANH(tmp1); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp0); \
|
|
|
|
|
_mm256_storeu_ps(y + 8, tmp1); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRI_GT8LT16_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
__m256 tmp = _mm256_loadu_ps(x); \
|
|
|
|
|
INTRI_VTANH(tmp); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp); \
|
|
|
|
|
x += AVX_FLOAT_BLOCK; \
|
|
|
|
|
y += AVX_FLOAT_BLOCK; \
|
|
|
|
|
const int rest = n - AVX_FLOAT_BLOCK; \
|
|
|
|
|
vscal_->Compute(rest, 2.f, x, y); \
|
|
|
|
|
vsigmoid_->Compute(rest, y, y); \
|
|
|
|
|
vscal_->Compute(rest, 2.f, y); \
|
|
|
|
|
vaddbias_->Compute(rest, -1.f, y, y); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRI_GT16_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VTanhKernelImpl<float, isa, kGT16>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
const int rest = n % AVX_FLOAT_BLOCK; \
|
|
|
|
|
const int end = n - rest; \
|
|
|
|
|
for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \
|
|
|
|
|
__m256 tmp = _mm256_loadu_ps(x + i); \
|
|
|
|
|
INTRI_VTANH(tmp); \
|
|
|
|
|
_mm256_storeu_ps(y + i, tmp); \
|
|
|
|
|
} \
|
|
|
|
|
x += end; \
|
|
|
|
|
y += end; \
|
|
|
|
|
vscal_->Compute(rest, 2.f, x, y); \
|
|
|
|
|
vsigmoid_->Compute(rest, y, y); \
|
|
|
|
|
vscal_->Compute(rest, 2.f, y); \
|
|
|
|
|
vaddbias_->Compute(rest, -1.f, y, y); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
INTRI8_FLOAT(jit::avx);
|
|
|
|
|
INTRI16_FLOAT(jit::avx);
|
|
|
|
|
INTRI_GT8LT16_FLOAT(jit::avx);
|
|
|
|
|
INTRI_GT16_FLOAT(jit::avx);
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX2__
|
|
|
|
|
INTRI8_FLOAT(jit::avx2);
|
|
|
|
|
INTRI16_FLOAT(jit::avx2);
|
|
|
|
|
// maybe use avx at gt8lt16 and gt16
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX512F__
|
|
|
|
|
INTRI8_FLOAT(jit::avx512f);
|
|
|
|
|
INTRI16_FLOAT(jit::avx512f);
|
|
|
|
|
// maybe use avx at gt8lt16 and gt16
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#undef INTRI8_FLOAT
|
|
|
|
|
#undef INTRI16_FLOAT
|
|
|
|
|
#undef INTRI_GT8LT16_FLOAT
|
|
|
|
|
#undef INTRI_GT16_FLOAT
|
|
|
|
|
#undef INTRI_VTANH
|
|
|
|
|
|
|
|
|
|
REGISTER_JITKERNEL_ARGS(vtanh, VTanhKernel, JITKERNEL_DECLARE, JITKERNEL_KEY,
|
|
|
|
|
JITKERNEL_NEW_ACT_IMPL);
|
|
|
|
|
|
|
|
|
|
#undef JITKERNEL_NEW_ACT_IMPL
|
|
|
|
|
|
|
|
|
|
} // namespace jitkernel
|
|
|
|
|