|
|
|
@ -177,14 +177,6 @@ bool VActJitCode::init(int d, operand_type type) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
|
|
|
|
|
vmaxps(ymm_dst, ymm_zero, ymm_src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
|
|
|
|
|
vmaxps(xmm_dst, xmm_zero, xmm_src);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
|
|
|
|
|
int fy_idx, int mask_idx, int tmp_idx) {
|
|
|
|
|
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
@ -378,7 +370,7 @@ void VActJitCode::generate() {
|
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_ymm(ymm_dst, ymm_src, ymm_zero);
|
|
|
|
|
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
@ -414,7 +406,7 @@ void VActJitCode::generate() {
|
|
|
|
|
}
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_xmm(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|