add vtanh jitcode of size 8

panyx0718-patch-1
tensor-tang 6 years ago
parent 046374bcd1
commit 6a159071b6

@ -168,24 +168,26 @@ void ReluJitCode::generate() {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
static const float exp_float_consts[] ALIGN32 = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
@ -216,6 +218,7 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
ymm_t ymm_fy = ymm_t(3);
ymm_t ymm_mask = ymm_t(4);
ymm_t ymm_tmp = ymm_t(5);
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
@ -327,6 +330,40 @@ void VSigmoidJitCode::generate() {
ret();
}
bool VTanhJitCode::init(int d) {
return MayIUse(avx) && d == 8; // only 8 yet
}
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
// y = 2 / (1 + e^(-2x)) - 1
// use ymm2, ymm3
reg64_t reg_ptr_global = rax;
ymm_t ymm_tmp = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vxorps(ymm_zero, ymm_zero, ymm_zero);
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
vmulps(ymm_src, ymm_src, ymm_tmp);
exp_ymm(ymm_src, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vaddps(ymm_dst, ymm_dst, ymm_tmp);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
vdivps(ymm_dst, ymm_tmp, ymm_dst);
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
vsubps(ymm_dst, ymm_dst, ymm_tmp);
pop(reg_ptr_global);
}
void VTanhJitCode::generate() {
int offset = 0;
vmovups(ymm_src, ptr[param1 + offset]);
vtanh_ymm(ymm_src, ymm_dst);
vmovups(ptr[param2 + offset], ymm_dst);
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math

@ -149,6 +149,26 @@ class VSigmoidJitCode : public VExpJitCode {
ymm_t ymm_dst = ymm_t(1);
};
class VTanhJitCode : public VExpJitCode {
public:
DECLARE_JIT_CODE(VTanhJitCode);
explicit VTanhJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
static bool init(int d);
void generate() override;
// compute sigmoid with ymm
void vtanh_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
private:
int num_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
ymm_t ymm_src = ymm_t(0);
ymm_t ymm_dst = ymm_t(1);
};
} // namespace gen
} // namespace jitkernel
} // namespace math

@ -132,6 +132,7 @@ template <typename T>
class VTanhKernel : public VActKernel<T> {
public:
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T>

File diff suppressed because it is too large Load Diff

@ -322,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->ComputeDeprecated(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();

Loading…
Cancel
Save