|
|
|
@@ -81,10 +81,10 @@ void VXXJitCode::generate() {
|
|
|
|
|
}
|
|
|
|
|
if (rest >= 2) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovups(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
vmovq(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovups(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
vmovq(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (type_ == operand_type::mul) {
|
|
|
|
|
vmulps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
@@ -100,10 +100,10 @@ void VXXJitCode::generate() {
|
|
|
|
|
}
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovups(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
vmovss(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovups(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
vmovss(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (type_ == operand_type::mul) {
|
|
|
|
|
vmulss(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
@@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) {
|
|
|
|
|
return ok;
|
|
|
|
|
} else if (type == operand_type::exp) {
|
|
|
|
|
// exp is slower than mkl when d >= 256
|
|
|
|
|
return ok && d % 8 == 0 && d < 256;
|
|
|
|
|
return ok; //&& d % 4 == 0 && d < 256;
|
|
|
|
|
} else {
|
|
|
|
|
// TODO(TJ): support more
|
|
|
|
|
return ok && d % 8 == 0;
|
|
|
|
@@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
|
|
|
|
|
vmaxps(ymm_dst, ymm_zero, ymm_src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ReLU on one 128-bit register: dst = max(0, src), element-wise over 4 floats.
// 128-bit twin of relu_ymm; used for the <8-element remainder of a vector.
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
  // NOTE(review): xmm_zero is presumably pre-filled with 0.0f by the caller
  // (generate()) — confirm. Operand order matters: vmaxps returns the second
  // source when a lane is NaN, so max(0, NaN) yields NaN here.
  vmaxps(xmm_dst, xmm_zero, xmm_src);
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
|
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
@@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::exp_xmm(xmm_t& ymm_dst, xmm_t& ymm_src, int fx_idx,
|
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
|
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
|
// check all idx can not equal
|
|
|
|
|
xmm_t ymm_fx = xmm_t(fx_idx);
|
|
|
|
|
xmm_t ymm_fy = xmm_t(fy_idx);
|
|
|
|
|
xmm_t ymm_mask = xmm_t(mask_idx);
|
|
|
|
|
xmm_t ymm_tmp = xmm_t(tmp_idx);
|
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
|
push(reg_ptr_global);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
|
|
|
|
|
vminps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
|
|
|
|
|
vmaxps(ymm_src, ymm_src, ymm_tmp);
|
|
|
|
|
// express exp(x) as exp(g + n*log(2))
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
|
|
|
|
|
vmulps(ymm_fx, ymm_src, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
|
|
|
|
|
vaddps(ymm_fx, ymm_fx, ymm_tmp);
|
|
|
|
|
vroundps(ymm_fy, ymm_fx, 0x01);
|
|
|
|
|
// if greater, substract 1
|
|
|
|
|
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vandps(ymm_mask, ymm_mask, ymm_tmp);
|
|
|
|
|
vsubps(ymm_fx, ymm_fy, ymm_mask);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
|
|
|
|
|
vmulps(ymm_fy, ymm_fx, ymm_tmp);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
|
|
|
|
|
xmm_t ymm_z = xmm_t(ymm_mask.getIdx());
|
|
|
|
|
vmulps(ymm_z, ymm_fx, ymm_tmp);
|
|
|
|
|
vsubps(ymm_src, ymm_src, ymm_fy);
|
|
|
|
|
vsubps(ymm_src, ymm_src, ymm_z);
|
|
|
|
|
vmulps(ymm_z, ymm_src, ymm_src);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
|
|
|
|
|
vmulps(ymm_dst, ymm_src, ymm_tmp);
|
|
|
|
|
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
|
|
|
|
|
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
vmulps(ymm_dst, ymm_dst, ymm_src);
|
|
|
|
|
}
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
vmulps(ymm_dst, ymm_dst, ymm_z);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_src);
|
|
|
|
|
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vaddps(ymm_dst, ymm_dst, ymm_tmp);
|
|
|
|
|
// build 2^n
|
|
|
|
|
xmm_t ymm_int = ymm_fx;
|
|
|
|
|
vcvttps2dq(ymm_int, ymm_fx);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
|
|
|
|
|
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vpaddd(ymm_int, ymm_int, ymm_tmp);
|
|
|
|
|
vpslld(ymm_int, ymm_int, 23);
|
|
|
|
|
vmulps(ymm_dst, ymm_dst, ymm_int);
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
|
// y = 1 / (1 + e^-x)
|
|
|
|
@@ -343,7 +406,7 @@ void VActJitCode::generate() {
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
|
|
|
|
}
|
|
|
|
|
if (type_ != operand_type::relu) {
|
|
|
|
|
if (type_ != operand_type::relu && type_ != operand_type::exp) {
|
|
|
|
|
// TODO(TJ): remove me
|
|
|
|
|
ret();
|
|
|
|
|
return;
|
|
|
|
@@ -351,21 +414,50 @@ void VActJitCode::generate() {
|
|
|
|
|
int rest = num_ % YMM_FLOAT_BLOCK;
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
vmovups(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * 4;
|
|
|
|
|
rest -= 4;
|
|
|
|
|
}
|
|
|
|
|
if (rest >= 2) {
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
vmovq(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
vmovq(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * 2;
|
|
|
|
|
rest -= 2;
|
|
|
|
|
}
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
// vmovups();
|
|
|
|
|
vmovss(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
vmovss(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
}
|
|
|
|
|
ret();
|
|
|
|
|