|
|
|
@ -221,10 +221,14 @@ void LSTMJitCode::generate() {
|
|
|
|
|
reg64_t reg_ptr_ct_1 = r9;
|
|
|
|
|
reg64_t reg_ptr_ct = r10;
|
|
|
|
|
reg64_t reg_ptr_ht = r11;
|
|
|
|
|
reg64_t reg_ptr_wp = r12;
|
|
|
|
|
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
|
|
|
|
|
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
|
|
|
|
|
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
|
|
|
|
|
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
|
|
|
|
|
if (use_peephole_) {
|
|
|
|
|
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int offset = 0;
|
|
|
|
|
int d = num_ * sizeof(float);
|
|
|
|
@ -235,13 +239,27 @@ void LSTMJitCode::generate() {
|
|
|
|
|
act<ymm_t>(ymm_c, ymm_src, act_cand_);
|
|
|
|
|
// i
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
|
|
|
|
|
if (!compute_c1h1_ && use_peephole_) {
|
|
|
|
|
ymm_t ymm_wp = ymm_t(2);
|
|
|
|
|
ymm_t ymm_ct_1 = ymm_t(3);
|
|
|
|
|
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]);
|
|
|
|
|
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
|
|
|
|
|
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
|
|
|
|
|
vaddps(ymm_src, ymm_src, ymm_wp);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_i, ymm_src, act_gate_);
|
|
|
|
|
vmulps(ymm_c, ymm_c, ymm_i);
|
|
|
|
|
if (!compute_c1h1_) {
|
|
|
|
|
// f
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
|
|
|
|
|
act<ymm_t>(ymm_f, ymm_src, act_gate_);
|
|
|
|
|
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
|
|
|
|
|
if (use_peephole_) {
|
|
|
|
|
ymm_t ymm_wp = ymm_t(3);
|
|
|
|
|
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]);
|
|
|
|
|
vmulps(ymm_wp, ymm_i, ymm_wp);
|
|
|
|
|
vaddps(ymm_src, ymm_src, ymm_wp);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_f, ymm_src, act_gate_);
|
|
|
|
|
vmulps(ymm_f, ymm_f, ymm_i);
|
|
|
|
|
vaddps(ymm_f, ymm_f, ymm_c);
|
|
|
|
|
}
|
|
|
|
@ -250,8 +268,14 @@ void LSTMJitCode::generate() {
|
|
|
|
|
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
|
|
|
|
|
ymm_t ymm_tmp = ymm_i;
|
|
|
|
|
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
|
|
|
|
|
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
|
|
|
|
|
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
|
|
|
|
|
if (use_peephole_) {
|
|
|
|
|
ymm_t ymm_wp = ymm_t(2);
|
|
|
|
|
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
|
|
|
|
|
vmulps(ymm_wp, ymm_ct, ymm_wp);
|
|
|
|
|
vaddps(ymm_src, ymm_src, ymm_wp);
|
|
|
|
|
}
|
|
|
|
|
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
|
|
|
|
|
act<ymm_t>(ymm_o, ymm_src, act_gate_);
|
|
|
|
|
vmulps(ymm_o, ymm_tmp, ymm_o);
|
|
|
|
|
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
|
|
|
|
|