|
|
|
@ -132,6 +132,111 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
|
|
|
|
|
std::shared_ptr<const VExpKernel<T>> vexp_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define INTRI_SIGMOID(tmp, min, max) \
|
|
|
|
|
tmp = _mm256_max_ps(tmp, min); \
|
|
|
|
|
tmp = _mm256_min_ps(tmp, max); \
|
|
|
|
|
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
|
|
|
|
|
tmp = detail::Exp(tmp); \
|
|
|
|
|
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
|
|
|
|
|
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
|
|
|
|
|
|
|
|
|
|
#define INTRI8_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
|
|
|
|
|
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
|
|
|
|
|
__m256 tmp = _mm256_loadu_ps(x); \
|
|
|
|
|
INTRI_SIGMOID(tmp, min, max); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRI16_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VSigmoidKernelImpl<float, isa, kEQ16>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
|
|
|
|
|
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
|
|
|
|
|
__m256 tmp0 = _mm256_loadu_ps(x); \
|
|
|
|
|
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
|
|
|
|
|
INTRI_SIGMOID(tmp0, min, max); \
|
|
|
|
|
INTRI_SIGMOID(tmp1, min, max); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp0); \
|
|
|
|
|
_mm256_storeu_ps(y + 8, tmp1); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRI_GT8LT16_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VSigmoidKernelImpl<float, isa, kGT8LT16>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
|
|
|
|
|
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
|
|
|
|
|
__m256 tmp = _mm256_loadu_ps(x); \
|
|
|
|
|
INTRI_SIGMOID(tmp, min, max); \
|
|
|
|
|
_mm256_storeu_ps(y, tmp); \
|
|
|
|
|
const float min_ = SIGMOID_THRESHOLD_MIN; \
|
|
|
|
|
const float max_ = SIGMOID_THRESHOLD_MAX; \
|
|
|
|
|
for (int i = AVX_FLOAT_BLOCK; i < n; ++i) { \
|
|
|
|
|
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
|
|
|
|
|
y[i] = 0.f - y[i]; \
|
|
|
|
|
} \
|
|
|
|
|
vexp_->Compute(n - AVX_FLOAT_BLOCK, y + AVX_FLOAT_BLOCK, \
|
|
|
|
|
y + AVX_FLOAT_BLOCK); \
|
|
|
|
|
for (int i = AVX_FLOAT_BLOCK; i < n; ++i) { \
|
|
|
|
|
y[i] = 1.f / (1.f + y[i]); \
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRI_GT16_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void VSigmoidKernelImpl<float, isa, kGT16>::Compute( \
|
|
|
|
|
const int n, const float* x, float* y) const { \
|
|
|
|
|
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
|
|
|
|
|
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
|
|
|
|
|
const int rest = n % AVX_FLOAT_BLOCK; \
|
|
|
|
|
const int end = n - rest; \
|
|
|
|
|
for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \
|
|
|
|
|
__m256 tmp = _mm256_loadu_ps(x + i); \
|
|
|
|
|
INTRI_SIGMOID(tmp, min, max); \
|
|
|
|
|
_mm256_storeu_ps(y + i, tmp); \
|
|
|
|
|
} \
|
|
|
|
|
const float min_ = SIGMOID_THRESHOLD_MIN; \
|
|
|
|
|
const float max_ = SIGMOID_THRESHOLD_MAX; \
|
|
|
|
|
for (int i = end; i < n; ++i) { \
|
|
|
|
|
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
|
|
|
|
|
y[i] = 0.f - y[i]; \
|
|
|
|
|
} \
|
|
|
|
|
vexp_->Compute(rest, y + end, y + end); \
|
|
|
|
|
for (int i = end; i < n; ++i) { \
|
|
|
|
|
y[i] = 1.f / (1.f + y[i]); \
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
INTRI8_FLOAT(jit::avx);
|
|
|
|
|
INTRI16_FLOAT(jit::avx);
|
|
|
|
|
INTRI_GT8LT16_FLOAT(jit::avx);
|
|
|
|
|
INTRI_GT16_FLOAT(jit::avx);
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX2__
|
|
|
|
|
INTRI8_FLOAT(jit::avx2);
|
|
|
|
|
INTRI16_FLOAT(jit::avx2);
|
|
|
|
|
INTRI_GT8LT16_FLOAT(jit::avx2);
|
|
|
|
|
INTRI_GT16_FLOAT(jit::avx2);
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX512F__
|
|
|
|
|
INTRI8_FLOAT(jit::avx512f);
|
|
|
|
|
INTRI16_FLOAT(jit::avx512f);
|
|
|
|
|
INTRI_GT8LT16_FLOAT(jit::avx512f);
|
|
|
|
|
INTRI_GT16_FLOAT(jit::avx512f);
|
|
|
|
|
#endif
|
|
|
|
|
// TODO(TJ): eq16 test and complete avx512
|
|
|
|
|
|
|
|
|
|
#undef INTRI8_FLOAT
|
|
|
|
|
#undef INTRI16_FLOAT
|
|
|
|
|
#undef INTRI_GT8LT16_FLOAT
|
|
|
|
|
#undef INTRI_GT16_FLOAT
|
|
|
|
|
|
|
|
|
|
#define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
|
|
|
|
|
p = std::dynamic_pointer_cast<ker<dtype>>( \
|
|
|
|
|
std::make_shared<ker##Impl<dtype, isa, k>>(d))
|
|
|
|
@ -140,6 +245,7 @@ REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
|
|
|
|
|
JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
|
|
|
|
|
|
|
|
|
|
#undef JITKERNEL_NEW_ACT_IMPL
|
|
|
|
|
|
|
|
|
|
} // namespace jitkernel
|
|
|
|
|
} // namespace math
|
|
|
|
|
} // namespace operators
|
|
|
|
|