|
|
|
@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VActJitCode::generate() {
|
|
|
|
|
xmm_t xmm_zero = xmm_t(2);
|
|
|
|
|
ymm_t ymm_zero = ymm_t(2);
|
|
|
|
|
if (type_ == operand_type::relu) {
|
|
|
|
|
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
|
|
|
|
}
|
|
|
|
|
int offset = 0;
|
|
|
|
|
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
|
|
|
|
vmovups(ymm_src, ptr[param1 + offset]);
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::identity:
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_dst, ymm_src, type_);
|
|
|
|
|
vmovups(ptr[param2 + offset], ymm_dst);
|
|
|
|
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
|
|
|
|
}
|
|
|
|
@ -182,22 +160,7 @@ void VActJitCode::generate() {
|
|
|
|
|
block = 1;
|
|
|
|
|
vmovss(xmm_src, ptr[param1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
switch (type_) {
|
|
|
|
|
case operand_type::relu:
|
|
|
|
|
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::exp:
|
|
|
|
|
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::sigmoid:
|
|
|
|
|
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
case operand_type::tanh:
|
|
|
|
|
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
act<xmm_t>(xmm_dst, xmm_src, type_);
|
|
|
|
|
if (rest >= 4) {
|
|
|
|
|
vmovups(ptr[param2 + offset], xmm_dst);
|
|
|
|
|
} else if (rest >= 2) {
|
|
|
|
@ -233,52 +196,64 @@ void LSTMJitCode::generate() {
|
|
|
|
|
int offset = 0;
|
|
|
|
|
int d = num_ * sizeof(float);
|
|
|
|
|
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
|
// c
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset]);
|
|
|
|
|
act<ymm_t>(ymm_c, ymm_src, act_cand_);
|
|
|
|
|
// i
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
|
|
|
|
|
if (!compute_c1h1_ && use_peephole_) {
|
|
|
|
|
ymm_t ymm_wp = ymm_t(2);
|
|
|
|
|
ymm_t ymm_ct_1 = ymm_t(3);
|
|
|
|
|
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]);
|
|
|
|
|
/* gates: W_ch, W_ih, W_fh, W_oh */
|
|
|
|
|
ymm_t ymm_c = ymm_t(0);
|
|
|
|
|
ymm_t ymm_i = ymm_t(1);
|
|
|
|
|
ymm_t ymm_f = ymm_t(2);
|
|
|
|
|
ymm_t ymm_o = ymm_t(3);
|
|
|
|
|
ymm_t ymm_ct_1 = ymm_t(4);
|
|
|
|
|
ymm_t ymm_wp0 = ymm_t(5);
|
|
|
|
|
ymm_t ymm_wp1 = ymm_t(6);
|
|
|
|
|
ymm_t ymm_wp2 = ymm_t(7);
|
|
|
|
|
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
|
|
|
|
|
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
|
|
|
|
|
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
|
|
|
|
|
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
|
|
|
|
|
if (!compute_c1h1_) {
|
|
|
|
|
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
|
|
|
|
|
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
|
|
|
|
|
vaddps(ymm_src, ymm_src, ymm_wp);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_i, ymm_src, act_gate_);
|
|
|
|
|
if (use_peephole_) {
|
|
|
|
|
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
|
|
|
|
|
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
|
|
|
|
|
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
|
|
|
|
|
}
|
|
|
|
|
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
|
|
|
|
|
// act_cand(c)
|
|
|
|
|
act<ymm_t>(ymm_c, ymm_c, act_cand_);
|
|
|
|
|
// act_gate(i) or act_gate(ct_1 * wp0 + i)
|
|
|
|
|
if (!compute_c1h1_ && use_peephole_) {
|
|
|
|
|
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
|
|
|
|
|
vaddps(ymm_i, ymm_i, ymm_wp0);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_i, ymm_i, act_gate_);
|
|
|
|
|
vmulps(ymm_c, ymm_c, ymm_i);
|
|
|
|
|
if (!compute_c1h1_) {
|
|
|
|
|
// f
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
|
|
|
|
|
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
|
|
|
|
|
// act_gate(f) or act_gate(ct_1 * wp1 + f)
|
|
|
|
|
if (use_peephole_) {
|
|
|
|
|
ymm_t ymm_wp = ymm_t(3);
|
|
|
|
|
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]);
|
|
|
|
|
vmulps(ymm_wp, ymm_i, ymm_wp);
|
|
|
|
|
vaddps(ymm_src, ymm_src, ymm_wp);
|
|
|
|
|
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
|
|
|
|
|
vaddps(ymm_f, ymm_f, ymm_wp1);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_f, ymm_src, act_gate_);
|
|
|
|
|
vmulps(ymm_f, ymm_f, ymm_i);
|
|
|
|
|
act<ymm_t>(ymm_f, ymm_f, act_gate_);
|
|
|
|
|
// ct
|
|
|
|
|
vmulps(ymm_f, ymm_f, ymm_ct_1);
|
|
|
|
|
vaddps(ymm_f, ymm_f, ymm_c);
|
|
|
|
|
}
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
/* H_t = act_cell(C_t) * act_gate(o) */
|
|
|
|
|
// act_cell(C_t)
|
|
|
|
|
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
|
|
|
|
|
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
|
|
|
|
|
ymm_t ymm_tmp = ymm_i;
|
|
|
|
|
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
|
|
|
|
|
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
|
|
|
|
|
// act_gate(o) or act_gate(ct * wp2 + o)
|
|
|
|
|
if (use_peephole_) {
|
|
|
|
|
ymm_t ymm_wp = ymm_t(2);
|
|
|
|
|
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
|
|
|
|
|
vmulps(ymm_wp, ymm_ct, ymm_wp);
|
|
|
|
|
vaddps(ymm_src, ymm_src, ymm_wp);
|
|
|
|
|
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
|
|
|
|
|
vaddps(ymm_o, ymm_o, ymm_wp2);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
|
|
|
|
|
act<ymm_t>(ymm_o, ymm_src, act_gate_);
|
|
|
|
|
vmulps(ymm_o, ymm_tmp, ymm_o);
|
|
|
|
|
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
|
|
|
|
|
act<ymm_t>(ymm_o, ymm_o, act_gate_);
|
|
|
|
|
// ht
|
|
|
|
|
vmulps(ymm_o, ymm_o, ymm_tmp);
|
|
|
|
|
// save ct and ht
|
|
|
|
|
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
|
|
|
|
|
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
|
|
|
|
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
|
|
|
|
|
|
|
|
|
|
void GRUJitCode::generate() {
|
|
|
|
|
reg64_t reg_ptr_gates = rax;
|
|
|
|
|
reg64_t reg_ptr_ct_1 = r9;
|
|
|
|
|
reg64_t reg_ptr_ct = r10;
|
|
|
|
|
reg64_t reg_ptr_ht = r11;
|
|
|
|
|
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
|
|
|
|
|
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
|
|
|
|
|
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
|
|
|
|
|
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
|
|
|
|
|
reg64_t reg_ptr_ht_1 = r9;
|
|
|
|
|
reg64_t reg_ptr_ht = r10;
|
|
|
|
|
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
|
|
|
|
|
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
|
|
|
|
|
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
|
|
|
|
|
ymm_t ymm_one = ymm_t(0);
|
|
|
|
|
|
|
|
|
|
if (id_ == 2) {
|
|
|
|
|
reg64_t reg_ptr_tmp = r11;
|
|
|
|
|
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
|
|
|
|
|
}
|
|
|
|
|
int offset = 0;
|
|
|
|
|
int d = num_ * sizeof(float);
|
|
|
|
|
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
|
|
|
|
ymm_t ymm_u = ymm_t(1);
|
|
|
|
|
ymm_t ymm_r = ymm_t(2);
|
|
|
|
|
ymm_t ymm_s = ymm_t(3);
|
|
|
|
|
ymm_t ymm_ht_1 = ymm_t(4);
|
|
|
|
|
// W: {W_update, W_reset; W_state}
|
|
|
|
|
if (id_ == 0 || id_ == 2) {
|
|
|
|
|
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
|
|
|
|
|
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
|
|
|
|
|
}
|
|
|
|
|
if (id_ == 1) {
|
|
|
|
|
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
|
|
|
|
|
}
|
|
|
|
|
if (id_ == 1 || id_ == 2) {
|
|
|
|
|
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (id_ == 0) {
|
|
|
|
|
// ht = act_gate(u) * act_cand(s)
|
|
|
|
|
act<ymm_t>(ymm_u, ymm_u, act_gate_);
|
|
|
|
|
act<ymm_t>(ymm_s, ymm_s, act_cand_);
|
|
|
|
|
vmulps(ymm_s, ymm_s, ymm_u);
|
|
|
|
|
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
|
|
|
|
|
} else if (id_ == 1) {
|
|
|
|
|
// ht = act_gate(r) * ht_1
|
|
|
|
|
act<ymm_t>(ymm_r, ymm_r, act_gate_);
|
|
|
|
|
vmulps(ymm_r, ymm_r, ymm_ht_1);
|
|
|
|
|
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
|
|
|
|
|
} else if (id_ == 2) {
|
|
|
|
|
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
|
|
|
|
|
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
|
|
|
|
|
act<ymm_t>(ymm_u, ymm_u, act_gate_);
|
|
|
|
|
act<ymm_t>(ymm_s, ymm_s, act_cand_);
|
|
|
|
|
vmulps(ymm_s, ymm_s, ymm_u);
|
|
|
|
|
vsubps(ymm_u, ymm_one_inner, ymm_u);
|
|
|
|
|
vmulps(ymm_u, ymm_ht_1, ymm_u);
|
|
|
|
|
vaddps(ymm_u, ymm_s, ymm_u);
|
|
|
|
|
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
|
|
|
|
|
}
|
|
|
|
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ret();
|
|
|
|
|
}
|
|
|
|
|