|
|
|
@ -159,8 +159,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
|
|
|
|
|
T *gate_grad, T *prev_out_value,
|
|
|
|
|
T *prev_out_grad, T *reset_output_grad,
|
|
|
|
|
int frame_size, int batch_size,
|
|
|
|
|
ActivationType active_gate,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
ActivationType active_gate) {
|
|
|
|
|
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (frame_idx >= frame_size) return;
|
|
|
|
|
int batch_idx = 0;
|
|
|
|
@ -190,7 +189,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
|
|
|
|
|
|
|
|
|
|
op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
|
|
|
|
|
&r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
|
|
|
|
|
&r_reset_output_grad, active_gate, origin_mode);
|
|
|
|
|
&r_reset_output_grad, active_gate);
|
|
|
|
|
|
|
|
|
|
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
|
|
|
|
|
gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
|
|
|
|
|