|
|
|
@ -13,9 +13,6 @@ limitations under the License. */
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
#include <immintrin.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -121,7 +118,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
|
|
|
|
|
if (rest_ != 0) { \
|
|
|
|
|
j = offset + this->num_ - block; \
|
|
|
|
|
tmp = _mm256_loadu_ps((const float*)x + j); \
|
|
|
|
|
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, (__m256)mask_vec); \
|
|
|
|
|
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
|
|
|
|
|
sum = _mm256_add_ps(sum, tmp); \
|
|
|
|
|
} \
|
|
|
|
|
hi = _mm256_extractf128_ps(sum, 1); \
|
|
|
|
@ -145,7 +142,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
|
|
|
|
|
j = offset + this->num_ - block; \
|
|
|
|
|
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
|
|
|
|
|
tmp = _mm256_mul_ps(tmp, tmp); \
|
|
|
|
|
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, (__m256)mask_vec); \
|
|
|
|
|
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
|
|
|
|
|
sum = _mm256_add_ps(sum, tmp); \
|
|
|
|
|
} \
|
|
|
|
|
hi = _mm256_extractf128_ps(sum, 1); \
|
|
|
|
|