|
|
|
@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define INTRIAVX_FLOAT(isa, block) \
|
|
|
|
|
#define INTRIAVX_FLOAT(isa, jit_block) \
|
|
|
|
|
template <> \
|
|
|
|
|
LayerNormKernelImpl<float, isa, block>::LayerNormKernelImpl(int right) \
|
|
|
|
|
LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
|
|
|
|
|
: LayerNormKernel<float>() { \
|
|
|
|
|
this->num_ = right; \
|
|
|
|
|
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
|
|
|
|
|
this->end_ = this->num_ - this->rest_; \
|
|
|
|
|
} \
|
|
|
|
|
template <> \
|
|
|
|
|
void LayerNormKernelImpl<float, platform::avx, block>::Compute( \
|
|
|
|
|
void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
|
|
|
|
|
float* x, float* out, float* mean, float* var, const float* scale, \
|
|
|
|
|
const float* bias, int height, const float epsilon) const { \
|
|
|
|
|
__m256 sum; \
|
|
|
|
@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
|
|
|
|
|
__m256 tmp; \
|
|
|
|
|
size_t offset; \
|
|
|
|
|
size_t j; \
|
|
|
|
|
size_t block = YMM_FLOAT_BLOCK; \
|
|
|
|
|
__m256 reverse_num_vec = \
|
|
|
|
|
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
|
|
|
|
|
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \
|
|
|
|
@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx, kGT8LT16);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx, kEQ16);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx, kGT16);
|
|
|
|
|
#endif
|
|
|
|
|
// Expand INTRIAVX_FLOAT once per (ISA tag, jit block kind) pair so that the
// float specializations of LayerNormKernelImpl (constructor and Compute) are
// defined for each combination. Only compiled when __AVX2__ is defined.
#ifdef __AVX2__
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx2, kEQ8);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx2, kEQ16);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx2, kGT16);
|
|
|
|
|
// The avx512f tag is instantiated under the same __AVX2__ guard as avx2; the
// visible part of the macro body works on 256-bit __m256 vectors.
// NOTE(review): presumably avx512f deliberately reuses the 256-bit
// implementation here -- confirm against the full macro body.
INTRIAVX_FLOAT(platform::avx512f, kEQ8);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx512f, kGT8LT16);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx512f, kEQ16);
|
|
|
|
|
INTRIAVX_FLOAT(platform::avx512f, kGT16);
|
|
|
|
|
#endif  // __AVX2__
|
|
|
|
|
|
|
|
|
|
#undef INTRIAVX_FLOAT
|
|
|
|
|