|
|
|
@ -168,24 +168,26 @@ void ReluJitCode::generate() {
|
|
|
|
|
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
|
|
|
|
|
|
|
|
|
|
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
|
|
|
|
|
|
|
|
|
|
static const float exp_float_consts[] ALIGN32 = {
|
|
|
|
|
REPEAT_8TIMES(1.f),
|
|
|
|
|
REPEAT_8TIMES(2.f),
|
|
|
|
|
REPEAT_8TIMES(0.5f),
|
|
|
|
|
REPEAT_8TIMES(EXP_HIG),
|
|
|
|
|
REPEAT_8TIMES(EXP_LOW),
|
|
|
|
@ -216,6 +218,7 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
|
ymm_t ymm_fy = ymm_t(3);
|
|
|
|
|
ymm_t ymm_mask = ymm_t(4);
|
|
|
|
|
ymm_t ymm_tmp = ymm_t(5);
|
|
|
|
|
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
|
push(reg_ptr_global);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
|
|
|
|
@ -327,6 +330,40 @@ void VSigmoidJitCode::generate() {
|
|
|
|
|
ret();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool VTanhJitCode::init(int d) {
|
|
|
|
|
return MayIUse(avx) && d == 8; // only 8 yet
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
|
// y = 2 / (1 + e^(-2x)) - 1
|
|
|
|
|
// use ymm2, ymm3
|
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
|
ymm_t ymm_tmp = ymm_t(2);
|
|
|
|
|
ymm_t ymm_zero = ymm_t(3);
|
|
|
|
|
push(reg_ptr_global);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
|
|
|
|
|
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
|
|
|
|
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
|
|
|
|
|
vmulps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
|
exp_ymm(ymm_src, ymm_dst);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
|
|
|
|
|
vdivps(ymm_dst, ymm_tmp, ymm_dst);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
|
|
|
|
|
vsubps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VTanhJitCode::generate() {
|
|
|
|
|
int offset = 0;
|
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
|
vtanh_ymm(ymm_src, ymm_dst);
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
ret();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace gen
|
|
|
|
|
} // namespace jitkernel
|
|
|
|
|
} // namespace math
|
|
|
|
|