|
|
|
@ -60,60 +60,53 @@ void VXXJitCode::generate() {
|
|
|
|
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
|
|
|
|
}
|
|
|
|
|
int rest = num_ % YMM_FLOAT_BLOCK;
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovups(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovups(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (type_ == operand_type::mul) {
|
|
|
|
|
vmulps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
} else if (type_ == operand_type::add) {
|
|
|
|
|
vaddps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
}
|
|
|
|
|
if (with_relu_) {
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_dst);
|
|
|
|
|
}
|
|
|
|
|
vmovups(ptr[param3 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * 4;
|
|
|
|
|
rest -= 4;
|
|
|
|
|
}
|
|
|
|
|
if (rest >= 2) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovq(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovq(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
int block = XMM_FLOAT_BLOCK;
|
|
|
|
|
while (rest > 0) {
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovups(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovups(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
} else if (rest >= 2) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovq(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovq(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovss(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovss(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (type_ == operand_type::mul) {
|
|
|
|
|
vmulps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
} else if (type_ == operand_type::add) {
|
|
|
|
|
vaddps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::mul:
|
|
|
|
|
vmulps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::add:
|
|
|
|
|
vaddps(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (with_relu_) {
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_dst);
|
|
|
|
|
}
|
|
|
|
|
vmovq(ptr[param3 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * 2;
|
|
|
|
|
rest -= 2;
|
|
|
|
|
}
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
if (scalar_index_ != 1) {
|
|
|
|
|
vmovss(xmm_src1, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (scalar_index_ != 2) {
|
|
|
|
|
vmovss(xmm_src2, ptr[param2 + offset]);
|
|
|
|
|
}
|
|
|
|
|
if (type_ == operand_type::mul) {
|
|
|
|
|
vmulss(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
} else if (type_ == operand_type::add) {
|
|
|
|
|
vaddss(xmm_dst, xmm_src1, xmm_src2);
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
vmovups(ptr[param3 + offset], xmm_dst);
|
|
|
|
|
} else if (rest >= 2) {
|
|
|
|
|
vmovq(ptr[param3 + offset], xmm_dst);
|
|
|
|
|
} else {
|
|
|
|
|
vmovss(ptr[param3 + offset], xmm_dst);
|
|
|
|
|
}
|
|
|
|
|
if (with_relu_) {
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_dst);
|
|
|
|
|
}
|
|
|
|
|
vmovss(ptr[param3 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * block;
|
|
|
|
|
rest -= block;
|
|
|
|
|
block /= 2;
|
|
|
|
|
}
|
|
|
|
|
ret();
|
|
|
|
|
}
|
|
|
|
@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
|
|
|
|
|
|
|
|
|
|
bool VActJitCode::init(int d, operand_type type) {
|
|
|
|
|
bool ok = MayIUse(avx);
|
|
|
|
|
if (type == operand_type::relu) {
|
|
|
|
|
if (type == operand_type::relu || type == operand_type::exp) {
|
|
|
|
|
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
|
|
|
|
|
return ok;
|
|
|
|
|
} else if (type == operand_type::exp) {
|
|
|
|
|
// exp is slower than mkl when d >= 256
|
|
|
|
|
return ok; //&& d % 4 == 0 && d < 256;
|
|
|
|
|
} else {
|
|
|
|
|
// TODO(TJ): support more
|
|
|
|
|
return ok && d % 8 == 0;
|
|
|
|
@ -412,24 +403,15 @@ void VActJitCode::generate() {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
int rest = num_ % YMM_FLOAT_BLOCK;
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
int block = XMM_FLOAT_BLOCK;
|
|
|
|
|
while (rest > 0) {
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
vmovups(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
} else if (rest >= 2) {
|
|
|
|
|
vmovq(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
} else {
|
|
|
|
|
vmovss(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
vmovups(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * 4;
|
|
|
|
|
rest -= 4;
|
|
|
|
|
}
|
|
|
|
|
if (rest >= 2) {
|
|
|
|
|
vmovq(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
@ -440,25 +422,16 @@ void VActJitCode::generate() {
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
vmovq(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * 2;
|
|
|
|
|
rest -= 2;
|
|
|
|
|
}
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
// vmovups();
|
|
|
|
|
vmovss(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
vmovups(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
} else if (rest >= 2) {
|
|
|
|
|
vmovq(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
} else {
|
|
|
|
|
vmovss(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
}
|
|
|
|
|
vmovss(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
offset += sizeof(float) * block;
|
|
|
|
|
rest -= block;
|
|
|
|
|
block /= 2;
|
|
|
|
|
}
|
|
|
|
|
ret();
|
|
|
|
|
}
|
|
|
|
|