|
|
|
@ -20,10 +20,6 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/platform/dynload/mklml.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
#include <immintrin.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace math {
|
|
|
|
@ -66,14 +62,18 @@ namespace detail {
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
|
|
|
|
|
// ALIGN32: compiler-specific 32-byte alignment specifier.
// When _WIN32 is defined it expands to the __declspec(align(32)) spelling;
// otherwise it expands to the GCC-style __attribute__((aligned(32))) spelling.
// NOTE(review): the _WIN32 branch assumes a compiler that accepts __declspec
// (e.g. MSVC) -- confirm if other Windows toolchains are expected.
#if defined(_WIN32)
|
|
|
|
|
#define ALIGN32 __declspec(align(32))
|
|
|
|
|
#else
|
|
|
|
|
#define ALIGN32 __attribute__((aligned(32)))
|
|
|
|
|
#endif // _WIN32
|
|
|
|
|
|
|
|
|
|
// _PS256_CONST(Name, Val): declares a static const array of 8 floats named
// _ps256_<Name>, tagged with ALIGN32, with every element initialised to Val.
// NOTE(review): two spellings of the declaration line appear below, differing
// only in where ALIGN32 sits relative to the declarator; they look like the old
// and new sides of a diff -- confirm which one the real file keeps.
#define _PS256_CONST(Name, Val) \
|
|
|
|
|
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
|
|
|
|
|
static const float ALIGN32 _ps256_##Name[8] = {Val, Val, Val, Val, \
|
|
|
|
|
Val, Val, Val, Val}
|
|
|
|
|
|
|
|
|
|
// _PI256_CONST(Name, Val): declares a static const array of 8 ints named
// _pi256_<Name>, tagged with ALIGN32, with every element initialised to Val.
// NOTE(review): two spellings of the declaration line appear below, differing
// only in where ALIGN32 sits relative to the declarator; they look like the old
// and new sides of a diff -- confirm which one the real file keeps.
#define _PI256_CONST(Name, Val) \
|
|
|
|
|
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
|
|
|
|
|
static const int ALIGN32 _pi256_##Name[8] = {Val, Val, Val, Val, \
|
|
|
|
|
Val, Val, Val, Val}
|
|
|
|
|
|
|
|
|
|
|
// Instantiates the table _pi256_0x7f: eight ints, each equal to 0x7f.
_PI256_CONST(0x7f, 0x7f);
|
|
|
|
@ -98,7 +98,7 @@ typedef union imm_xmm_union {
|
|
|
|
|
|
|
|
|
|
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
|
|
|
|
|
{ \
|
|
|
|
|
imm_xmm_union u ALIGN32; \
|
|
|
|
|
imm_xmm_union ALIGN32 u; \
|
|
|
|
|
u.imm = imm_; \
|
|
|
|
|
xmm0_ = u.xmm[0]; \
|
|
|
|
|
xmm1_ = u.xmm[1]; \
|
|
|
|
@ -106,7 +106,7 @@ typedef union imm_xmm_union {
|
|
|
|
|
|
|
|
|
|
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
|
|
|
|
|
{ \
|
|
|
|
|
imm_xmm_union u ALIGN32; \
|
|
|
|
|
imm_xmm_union ALIGN32 u; \
|
|
|
|
|
u.xmm[0] = xmm0_; \
|
|
|
|
|
u.xmm[1] = xmm1_; \
|
|
|
|
|
imm_ = u.imm; \
|
|
|
|
@ -508,12 +508,14 @@ class VTanhKernelImpl : public VTanhKernel<T> {
|
|
|
|
|
vaddbias_->Compute(-1.f, y, y); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Exclude the AVX intrinsic kernel instantiations below on Windows builds.
// Test _WIN32, the macro the ALIGN32 selection in this file also keys on:
// MSVC predefines _WIN32 but not the __WIN32 spelling, so a check on __WIN32
// would not take effect there.
#ifndef _WIN32
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
|
|
|
|
|
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
|
|
|
|
|
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
|
|
|
|
|
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
|
|
|
|
|
#endif
|
|
|
|
|
#endif // AVX
|
|
|
|
|
#endif // WIN32
|
|
|
|
|
#ifdef __AVX2__
|
|
|
|
|
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
|
|
|
|
|
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
|
|
|
|
|