|
|
|
@ -151,6 +151,132 @@ void ReluJitCode::generate() {
|
|
|
|
|
}
|
|
|
|
|
ret();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool VExpJitCode::init(int d) {
|
|
|
|
|
return MayIUse(avx) && d == 8; // only 8 yet
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Alignment attribute (GCC/Clang spelling) for the constant tables below,
// which are read with aligned 256-bit loads (vmovaps / vmovdqa).
#define ALIGN32 __attribute__((aligned(32)))

// Clamp bounds applied to the input before the approximation.
#define EXP_HIG 88.3762626647949f
#define EXP_LOW (-88.3762626647949f)

// log2(e), used to turn x into a power-of-two exponent.
#define CEPHES_LOG2EF 1.44269504088896341
// C1 + C2 ~= ln(2), split in two parts for the range reduction step.
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 (-2.12194440e-4)

// Polynomial coefficients, applied from P0 (highest order) to P5.
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1

// Expands to eight comma-separated copies of `val` (one 8-lane float block).
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val

// Byte offsets of each 8-float block inside exp_float_consts. Block 0 (the
// 1.0f block) has no macro and is addressed with a zero displacement.
// Every expansion is parenthesized so the offsets stay correct when used next
// to operators of equal or higher precedence (e.g. `n / OFFSET_EXP_0P5`).
#define OFFSET_EXP_0P5 (1 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_HIG (2 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOW (3 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOG2EF (4 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_C1 (5 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_C2 (6 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P0 (7 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P1 (8 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P2 (9 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P3 (10 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P4 (11 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P5 (12 * AVX_FLOAT_BLOCK * sizeof(float))
|
|
|
|
|
|
|
|
|
|
// Float constants for the exp kernel, stored as consecutive 8-float blocks so
// each one can be fetched with a single aligned 256-bit load. The block order
// must match the OFFSET_EXP_* byte offsets.
static const float exp_float_consts[] ALIGN32 = {
    REPEAT_8TIMES(1.f),            // block 0: ones (zero displacement)
    REPEAT_8TIMES(0.5f),           // block 1: OFFSET_EXP_0P5
    REPEAT_8TIMES(EXP_HIG),        // block 2: OFFSET_EXP_HIG
    REPEAT_8TIMES(EXP_LOW),        // block 3: OFFSET_EXP_LOW
    REPEAT_8TIMES(CEPHES_LOG2EF),  // block 4: OFFSET_EXP_LOG2EF
    REPEAT_8TIMES(CEPHES_EXP_C1),  // block 5: OFFSET_EXP_C1
    REPEAT_8TIMES(CEPHES_EXP_C2),  // block 6: OFFSET_EXP_C2
    REPEAT_8TIMES(CEPHES_EXP_P0),  // block 7: OFFSET_EXP_P0
    REPEAT_8TIMES(CEPHES_EXP_P1),  // block 8: OFFSET_EXP_P1
    REPEAT_8TIMES(CEPHES_EXP_P2),  // block 9: OFFSET_EXP_P2
    REPEAT_8TIMES(CEPHES_EXP_P3),  // block 10: OFFSET_EXP_P3
    REPEAT_8TIMES(CEPHES_EXP_P4),  // block 11: OFFSET_EXP_P4
    REPEAT_8TIMES(CEPHES_EXP_P5),  // block 12: OFFSET_EXP_P5
};

// Eight copies of 0x7f, added to the integer exponent before it is shifted
// left by 23 bits in VExpJitCode::generate.
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};

// Zero-initialized 64-byte scratch area (two 256-bit slots) that the AVX-only
// path of the generated code stores to and reloads from.
// NOTE(review): this is one static buffer written by the generated code at
// run time, so concurrent executions would share it -- confirm callers cannot
// run the AVX-only kernel from several threads at once.
static int g_tmp_mem[16] ALIGN32 = {};
|
|
|
|
|
|
|
|
|
|
void VExpJitCode::generate() {
|
|
|
|
|
preCode();
|
|
|
|
|
// push some?
|
|
|
|
|
// in: ymm0, out: ymm1
|
|
|
|
|
// use ymm 0~5 (and ymm 14~15 if avx only)
|
|
|
|
|
int offset = 0;
|
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
|
|
|
|
|
vminps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
|
|
|
|
|
vmaxps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
|
// express exp(x) as exp(g + n*log(2))
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
|
|
|
|
|
vmulps(ymm_fx, ymm_src, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
|
|
|
|
|
vaddps(ymm_fx, ymm_fx, ymm_tmp);
|
|
|
|
|
vroundps(ymm_fy, ymm_fx, 0x01);
|
|
|
|
|
// if greater, substract 1
|
|
|
|
|
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vandps(ymm_mask, ymm_mask, ymm_tmp);
|
|
|
|
|
vsubps(ymm_fx, ymm_fy, ymm_mask);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
|
|
|
|
|
vmulps(ymm_fy, ymm_fx, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
|
|
|
|
|
vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask
|
|
|
|
|
vsubps(ymm_src, ymm_src, ymm_fy);
|
|
|
|
|
vsubps(ymm_src, ymm_src, ymm_z);
|
|
|
|
|
vmulps(ymm_z, ymm_src, ymm_src);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
|
|
|
|
|
vmulps(ymm_dst, ymm_src, ymm_tmp);
|
|
|
|
|
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
|
|
|
|
|
i += (AVX_FLOAT_BLOCK * sizeof(float))) {
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
vmulps(ymm_dst, ymm_dst, ymm_src);
|
|
|
|
|
}
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
vmulps(ymm_dst, ymm_dst, ymm_z);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_src);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
|
|
|
|
|
// build 2^n
|
|
|
|
|
ymm_t ymm_int = ymm_fx;
|
|
|
|
|
vcvttps2dq(ymm_int, ymm_fx);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
|
|
|
|
|
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
if (MayIUse(avx2)) {
|
|
|
|
|
vpaddd(ymm_int, ymm_int, ymm_tmp);
|
|
|
|
|
vpslld(ymm_int, ymm_int, 23);
|
|
|
|
|
} else if (MayIUse(avx)) {
|
|
|
|
|
// use ymm_int, ymm_tmp and reg_ptr_global
|
|
|
|
|
xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int
|
|
|
|
|
xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem));
|
|
|
|
|
vmovdqa(ptr[reg_ptr_global], ymm_int);
|
|
|
|
|
vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
|
|
|
|
|
vpaddd(xtmp1, xtmp1, xtmp2);
|
|
|
|
|
vpslld(xtmp1, xtmp1, 23);
|
|
|
|
|
vmovdqa(ptr[reg_ptr_global], xtmp1);
|
|
|
|
|
// next 128bits
|
|
|
|
|
vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]);
|
|
|
|
|
vmovdqa(xtmp2,
|
|
|
|
|
ptr[reg_ptr_global +
|
|
|
|
|
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
|
|
|
|
|
vpaddd(xtmp1, xtmp1, xtmp2);
|
|
|
|
|
vpslld(xtmp1, xtmp1, 23);
|
|
|
|
|
vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
|
|
|
|
|
// load out
|
|
|
|
|
vmovdqa(ymm_int, ptr[reg_ptr_global]);
|
|
|
|
|
}
|
|
|
|
|
vmulps(ymm_dst, ymm_dst, ymm_int);
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
|
|
|
|
|
// ret();
|
|
|
|
|
postCode();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace gen
|
|
|
|
|
} // namespace jitkernel
|
|
|
|
|
} // namespace math
|
|
|
|
|