|
|
|
@ -298,8 +298,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
|
|
|
|
|
T *gate_grad, T *prev_out_value,
|
|
|
|
|
T *prev_out_grad, T *reset_output_grad,
|
|
|
|
|
int frame_size,
|
|
|
|
|
ActivationType active_gate,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
ActivationType active_gate) {
|
|
|
|
|
T r_update_gate_value;
|
|
|
|
|
T r_update_gate_grad;
|
|
|
|
|
T r_reset_gate_value;
|
|
|
|
@ -329,8 +328,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
|
|
|
|
|
|
|
|
|
|
op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
|
|
|
|
|
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
|
|
|
|
|
&r_prev_out_grad, &r_reset_output_grad, active_gate,
|
|
|
|
|
origin_mode);
|
|
|
|
|
&r_prev_out_grad, &r_reset_output_grad, active_gate);
|
|
|
|
|
|
|
|
|
|
update_gate_grad[i] = r_update_gate_grad;
|
|
|
|
|
reset_gate_grad[i] = r_reset_gate_grad;
|
|
|
|
@ -389,8 +387,8 @@ template <class OpResetGrad, typename T>
|
|
|
|
|
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
|
|
|
|
|
T *gate_grad, T *prev_out_value,
|
|
|
|
|
T *prev_out_grad, T *reset_output_grad,
|
|
|
|
|
int frame_size, ActivationType active_gate,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
int frame_size,
|
|
|
|
|
ActivationType active_gate) {
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
__m256 r_update_gate_value;
|
|
|
|
|
__m256 r_update_gate_grad;
|
|
|
|
@ -422,8 +420,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
|
|
|
|
|
|
|
|
|
|
op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
|
|
|
|
|
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
|
|
|
|
|
&r_prev_out_grad, &r_reset_output_grad, active_gate,
|
|
|
|
|
origin_mode);
|
|
|
|
|
&r_prev_out_grad, &r_reset_output_grad, active_gate);
|
|
|
|
|
|
|
|
|
|
update_gate_grad[i] = r_update_gate_grad;
|
|
|
|
|
reset_gate_grad[i] = r_reset_gate_grad;
|
|
|
|
@ -469,18 +466,16 @@ template <class OpResetGrad, typename T>
|
|
|
|
|
inline void backward_reset_grad(OpResetGrad op_reset_grad,
|
|
|
|
|
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
|
|
|
|
|
int frame_size, int batch_size,
|
|
|
|
|
ActivationType active_gate, bool origin_mode) {
|
|
|
|
|
ActivationType active_gate) {
|
|
|
|
|
for (int b = 0; b < batch_size; b++) {
|
|
|
|
|
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
|
|
|
|
|
hl_avx_gru_backward_reset_grad(op_reset_grad, value.gate_value,
|
|
|
|
|
grad.gate_grad, value.prev_out_value,
|
|
|
|
|
grad.prev_out_grad, grad.reset_output_grad,
|
|
|
|
|
frame_size, active_gate, origin_mode);
|
|
|
|
|
hl_avx_gru_backward_reset_grad(
|
|
|
|
|
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
|
|
|
|
|
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
|
|
|
|
|
} else {
|
|
|
|
|
hl_naive_gru_backward_reset_grad(
|
|
|
|
|
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
|
|
|
|
|
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate,
|
|
|
|
|
origin_mode);
|
|
|
|
|
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
value.gate_value += frame_size * 3;
|
|
|
|
|