|
|
@ -118,40 +118,6 @@ void VXXJitCode::generate() {
|
|
|
|
ret();
|
|
|
|
ret();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool ReluJitCode::init(int d) { return MayIUse(avx); }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ReluJitCode::generate() {
|
|
|
|
|
|
|
|
int offset = 0;
|
|
|
|
|
|
|
|
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
|
|
|
|
|
|
|
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
|
|
|
|
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(ymm_dst, ymm_zero, ymm_src);
|
|
|
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
|
|
|
offset += sizeof(float) * AVX_FLOAT_BLOCK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
int rest = num_ % AVX_FLOAT_BLOCK;
|
|
|
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
|
|
|
vmovups(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
|
|
|
offset += sizeof(float) * 4;
|
|
|
|
|
|
|
|
rest -= 4;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (rest >= 2) {
|
|
|
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
|
|
|
vmovq(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
|
|
|
offset += sizeof(float) * 2;
|
|
|
|
|
|
|
|
rest -= 2;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
|
|
|
vmovss(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ret();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#define ALIGN32 __attribute__((aligned(32)))
|
|
|
|
#define ALIGN32 __attribute__((aligned(32)))
|
|
|
|
#define EXP_HIG 88.3762626647949f
|
|
|
|
#define EXP_HIG 88.3762626647949f
|
|
|
|
#define EXP_LOW -88.3762626647949f
|
|
|
|
#define EXP_LOW -88.3762626647949f
|
|
|
@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
|
|
|
|
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
|
|
|
|
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
|
|
|
|
static int g_tmp_mem[16] ALIGN32 = {0};
|
|
|
|
static int g_tmp_mem[16] ALIGN32 = {0};
|
|
|
|
|
|
|
|
|
|
|
|
bool VExpJitCode::init(int d) {
|
|
|
|
bool VActJitCode::init(int d, operand_type type) {
|
|
|
|
return MayIUse(avx) && d == 8; // only 8 yet
|
|
|
|
bool ok = MayIUse(avx);
|
|
|
|
|
|
|
|
if (type == operand_type::relu) {
|
|
|
|
|
|
|
|
return ok;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
return ok && d == 8; // only 8 yet
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
|
|
|
|
// use reg rax and ymm 2~5
|
|
|
|
vmaxps(ymm_dst, ymm_zero, ymm_src);
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
}
|
|
|
|
ymm_t ymm_fx = ymm_t(2);
|
|
|
|
|
|
|
|
ymm_t ymm_fy = ymm_t(3);
|
|
|
|
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
ymm_t ymm_mask = ymm_t(4);
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
ymm_t ymm_tmp = ymm_t(5);
|
|
|
|
|
|
|
|
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
|
|
|
|
// check all idx can not equal
|
|
|
|
|
|
|
|
ymm_t ymm_fx = ymm_t(fx_idx);
|
|
|
|
|
|
|
|
ymm_t ymm_fy = ymm_t(fy_idx);
|
|
|
|
|
|
|
|
ymm_t ymm_mask = ymm_t(mask_idx);
|
|
|
|
|
|
|
|
ymm_t ymm_tmp = ymm_t(tmp_idx);
|
|
|
|
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
push(reg_ptr_global);
|
|
|
|
push(reg_ptr_global);
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
|
|
|
@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void VExpJitCode::generate() {
|
|
|
|
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
int offset = 0;
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
// y = 1 / (1 + e^-x)
|
|
|
|
exp_ymm(ymm_src, ymm_dst);
|
|
|
|
ymm_t ymm_tmp = ymm_t(tmp_idx);
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
|
|
|
ret();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool VSigmoidJitCode::init(int d) {
|
|
|
|
|
|
|
|
return MayIUse(avx) && d == 8; // only 8 yet
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
|
|
|
|
// use ymm2
|
|
|
|
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
ymm_t ymm_tmp = ymm_t(2);
|
|
|
|
|
|
|
|
push(reg_ptr_global);
|
|
|
|
push(reg_ptr_global);
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
|
|
|
@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
vmaxps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
vmaxps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
|
|
|
|
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
|
|
|
|
vsubps(ymm_src, ymm_tmp, ymm_src);
|
|
|
|
vsubps(ymm_src, ymm_tmp, ymm_src);
|
|
|
|
exp_ymm(ymm_src, ymm_dst);
|
|
|
|
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
vdivps(ymm_dst, ymm_tmp, ymm_dst);
|
|
|
|
vdivps(ymm_dst, ymm_tmp, ymm_dst);
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void VSigmoidJitCode::generate() {
|
|
|
|
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
int offset = 0;
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
sigmoid_ymm(ymm_src, ymm_dst);
|
|
|
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
|
|
|
ret();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool VTanhJitCode::init(int d) {
|
|
|
|
|
|
|
|
return MayIUse(avx) && d == 8; // only 8 yet
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
|
|
|
|
// y = 2 / (1 + e^(-2x)) - 1
|
|
|
|
// y = 2 / (1 + e^(-2x)) - 1
|
|
|
|
// use ymm2, ymm3
|
|
|
|
ymm_t ymm_tmp = ymm_t(tmp_idx);
|
|
|
|
|
|
|
|
ymm_t ymm_zero = ymm_t(mask_idx);
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
ymm_t ymm_tmp = ymm_t(2);
|
|
|
|
|
|
|
|
ymm_t ymm_zero = ymm_t(3);
|
|
|
|
|
|
|
|
push(reg_ptr_global);
|
|
|
|
push(reg_ptr_global);
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
|
|
|
|
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
|
|
|
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
|
|
|
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
|
|
|
|
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
|
|
|
|
vmulps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
vmulps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
exp_ymm(ymm_src, ymm_dst);
|
|
|
|
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
|
|
|
@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void VTanhJitCode::generate() {
|
|
|
|
void VActJitCode::generate() {
|
|
|
|
|
|
|
|
xmm_t xmm_zero = xmm_t(2);
|
|
|
|
|
|
|
|
ymm_t ymm_zero = ymm_t(2);
|
|
|
|
|
|
|
|
if (type_ == operand_type::relu) {
|
|
|
|
|
|
|
|
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
|
|
|
|
|
|
|
}
|
|
|
|
int offset = 0;
|
|
|
|
int offset = 0;
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
|
|
|
|
vtanh_ymm(ymm_src, ymm_dst);
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
switch (type_) {
|
|
|
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
|
|
|
relu_ymm(ymm_dst, ymm_src, ymm_zero);
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
|
|
|
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
|
|
|
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
|
|
|
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
|
|
|
offset += sizeof(float) * AVX_FLOAT_BLOCK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (type_ != operand_type::relu) {
|
|
|
|
|
|
|
|
// TODO(TJ): remove me
|
|
|
|
|
|
|
|
ret();
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
int rest = num_ % AVX_FLOAT_BLOCK;
|
|
|
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
|
|
|
vmovups(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
|
|
|
offset += sizeof(float) * 4;
|
|
|
|
|
|
|
|
rest -= 4;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (rest >= 2) {
|
|
|
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
|
|
|
vmovq(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
|
|
|
offset += sizeof(float) * 2;
|
|
|
|
|
|
|
|
rest -= 2;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
|
|
|
vmovss(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
|
|
|
}
|
|
|
|
ret();
|
|
|
|
ret();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|