Fix the issue to run on AVX2 and AVX512F machines (#14851)

test=develop
ce_debug
Yihua Xu 6 years ago committed by tensor-tang
parent 45fb357b60
commit acc6ae49b1

@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
}
};
#define INTRIAVX_FLOAT(isa, block) \
#define INTRIAVX_FLOAT(isa, jit_block) \
template <> \
LayerNormKernelImpl<float, isa, block>::LayerNormKernelImpl(int right) \
LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
: LayerNormKernel<float>() { \
this->num_ = right; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
this->end_ = this->num_ - this->rest_; \
} \
template <> \
void LayerNormKernelImpl<float, platform::avx, block>::Compute( \
void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \
__m256 sum; \
@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
__m256 tmp; \
size_t offset; \
size_t j; \
size_t block = YMM_FLOAT_BLOCK; \
__m256 reverse_num_vec = \
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \
@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(platform::avx, kGT16);
#endif
#ifdef __AVX2__
INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(platform::avx2, kGT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ8);
INTRIAVX_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ16);
INTRIAVX_FLOAT(platform::avx512f, kGT16);
#endif
#undef INTRIAVX_FLOAT

Loading…
Cancel
Save