@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
 void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
-                                       ActivationType active_node) {
+                                       ActivationType active_node,
+                                       bool origin_mode) {
   T r_value_update_gate;
   T r_value_frame_state;
   T r_prev_out = 0;
@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
     }
 
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
-                    &r_output, active_node);
+                    &r_output, active_node, origin_mode);
 
     frame_state[i] = r_value_frame_state;
     output_value[i] = r_output;
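Note for reviewers: `op_final_output` is the functor from `detail/gru_kernel.h`, which this diff does not show. A minimal scalar sketch of the two update rules the flag is assumed to select between, with the candidate activation omitted for brevity (`u` = update gate, `c` = candidate/frame state, `h_prev` = previous output):

```cpp
#include <cstdio>

// Sketch only -- the real functor also applies active_node to the
// candidate before combining. origin_mode = true is assumed to pick the
// formulation from the original GRU paper; false keeps the default.
float gru_final_output(float u, float c, float h_prev, bool origin_mode) {
  if (origin_mode) {
    return h_prev - u * h_prev + u * c;  // h = (1 - u) * h_prev + u * c
  }
  return u * h_prev + c - u * c;  // h = u * h_prev + (1 - u) * c
}

int main() {
  std::printf("origin:  %f\n", gru_final_output(0.3f, 0.8f, 0.5f, true));
  std::printf("default: %f\n", gru_final_output(0.3f, 0.8f, 0.5f, false));
}
```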
@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
 void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *gate_value, T *prev_output_value,
                                      T *output_value, int frame_size,
-                                     ActivationType active_node) {
+                                     ActivationType active_node,
+                                     bool origin_mode) {
 #ifdef __AVX__
   __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
     }
 
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
-                    &r_output, active_node);
+                    &r_output, active_node, origin_mode);
 
     _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
                      r_value_frame_state);
@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
   if (rest > 0) {
     i = n - block;
     op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
-                    &r_prev_out_last, &r_output, active_node);
+                    &r_prev_out_last, &r_output, active_node, origin_mode);
 
     _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
                      r_value_frame_state_last);
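Context for the `_last` registers in this hunk: the kernel updates `frame_state` in place, so when `frame_size` is not a multiple of 8 the tail vector is preloaded before the main loop and written back at `i = n - block` with an overlapping store. A self-contained sketch of that idiom (an illustration, not the kernel itself; compile with `-mavx`, assumes `n >= 8`):

```cpp
#include <immintrin.h>

// Square a buffer in place, handling a ragged tail by saving the last
// 8-wide vector of inputs up front and re-emitting it at n - 8, so the
// overlapped elements are stored twice with identical values.
void square_in_place_avx(float *data, int n) {
  const int block = 8;
  const int rest = n % block;
  const int end = n - rest;
  __m256 tail = _mm256_set1_ps(0.0f);
  if (rest > 0) {
    tail = _mm256_loadu_ps(data + n - block);  // save before overwriting
  }
  for (int i = 0; i < end; i += block) {
    __m256 v = _mm256_loadu_ps(data + i);
    _mm256_storeu_ps(data + i, _mm256_mul_ps(v, v));
  }
  if (rest > 0) {
    // Recompute the overlap from the saved inputs; in-place buffers make
    // this preload necessary, exactly as with r_value_frame_state_last.
    _mm256_storeu_ps(data + n - block, _mm256_mul_ps(tail, tail));
  }
}
```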
@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
 template <class OpFinalOutput, typename T>
 inline void forward_final_output(OpFinalOutput op_final_output,
                                  GRUMetaValue<T> value, int frame_size,
-                                 int batch_size, ActivationType active_node) {
+                                 int batch_size, ActivationType active_node,
+                                 bool origin_mode) {
   for (int b = 0; b < batch_size; b++) {
     if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
         (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                       value.prev_out_value, value.output_value,
-                                      frame_size, active_node);
+                                      frame_size, active_node, origin_mode);
     } else {
       hl_naive_gru_forward_final_output(
           op_final_output, value.gate_value, value.prev_out_value,
-          value.output_value, frame_size, active_node);
+          value.output_value, frame_size, active_node, origin_mode);
     }
 
     value.gate_value += frame_size * 3;
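The wrapper dispatches to the AVX kernel only for `float` frames (`sizeof(T) == 4`) of at least 8 elements, and advances `value.gate_value` by `frame_size * 3` per batch row. A hypothetical indexing helper making the gate layout this stride assumes explicit (the helper name and shape are illustrative, not part of the header):

```cpp
// Assumed per-row layout behind the frame_size * 3 stride:
//   [ update gate | reset gate | candidate (frame state) ]
// gate_ptr is a hypothetical helper, not part of this codebase.
template <typename T>
inline T *gate_ptr(T *gate_value, int b, int g, int j, int frame_size) {
  // b = batch row, g = gate index (0 update, 1 reset, 2 candidate),
  // j = element within the frame.
  return gate_value + (b * 3 + g) * frame_size + j;
}
```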
@@ -253,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size,
-                                      ActivationType active_node) {
+                                      ActivationType active_node,
+                                      bool origin_mode) {
   T r_update_gate_value;
   T r_update_gate_grad;
   T r_frame_state_value;
@@ -279,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
 
     op_state_grad(&r_update_gate_value, &r_update_gate_grad,
                   &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
-                  &r_prev_out_grad, &r_out_grad, active_node);
+                  &r_prev_out_grad, &r_out_grad, active_node, origin_mode);
 
     update_gate_grad[i] = r_update_gate_grad;
     frame_state_grad[i] = r_frame_state_grad;
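Why the backward pass needs the same flag: differentiating the two output formulas with respect to the update gate yields terms of opposite sign, so `op_state_grad` cannot stay mode-agnostic. A scalar sketch of the two gradients (activation derivatives omitted; the real functor in `detail/gru_kernel.h` applies them):

```cpp
// g = gradient flowing back into the output, c = candidate state,
// h_prev = previous output, u = update gate value.
float update_gate_grad(float g, float c, float h_prev, bool origin_mode) {
  // origin_mode: h = (1 - u) * h_prev + u * c  =>  dh/du = c - h_prev
  // default:     h = u * h_prev + (1 - u) * c  =>  dh/du = h_prev - c
  return origin_mode ? g * (c - h_prev) : g * (h_prev - c);
}

float prev_out_grad(float g, float u, bool origin_mode) {
  // dh/dh_prev is (1 - u) in origin mode and u in the default mode.
  return origin_mode ? g * (1.0f - u) : g * u;
}
```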
@@ -338,8 +342,8 @@ template <class OpStateGrad, typename T>
 void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                     T *gate_grad, T *prev_out_value,
                                     T *prev_out_grad, T *output_grad,
-                                    int frame_size,
-                                    ActivationType active_node) {
+                                    int frame_size, ActivationType active_node,
+                                    bool origin_mode) {
 #ifdef __AVX__
   __m256 r_update_gate_value;
   __m256 r_update_gate_grad;
@@ -368,7 +372,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
 
     op_state_grad(&r_update_gate_value, &r_update_gate_grad,
                   &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
-                  &r_prev_out_grad, &r_out_grad, active_node);
+                  &r_prev_out_grad, &r_out_grad, active_node, origin_mode);
 
     update_gate_grad[i] = r_update_gate_grad;
     frame_state_grad[i] = r_frame_state_grad;
@@ -431,16 +435,18 @@ template <class OpStateGrad, typename T>
 inline void backward_state_grad(OpStateGrad op_state_grad,
                                 GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                 int frame_size, int batch_size,
-                                ActivationType active_node) {
+                                ActivationType active_node, bool origin_mode) {
   for (int b = 0; b < batch_size; b++) {
     if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
-      hl_avx_gru_backward_state_grad(
-          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
-          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
+      hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value,
+                                     grad.gate_grad, value.prev_out_value,
+                                     grad.prev_out_grad, grad.output_grad,
+                                     frame_size, active_node, origin_mode);
     } else {
-      hl_naive_gru_backward_state_grad(
-          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
-          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
+      hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value,
+                                       grad.gate_grad, value.prev_out_value,
+                                       grad.prev_out_grad, grad.output_grad,
+                                       frame_size, active_node, origin_mode);
     }
 
     value.gate_value += frame_size * 3;
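A closing observation on how the two modes relate: since the update gate is a sigmoid, `u = sigmoid(x)` implies `1 - u = sigmoid(-x)`, so the origin formulation with gate `u` equals the default formulation with gate `1 - u`; converting a trained model between the conventions amounts to negating the update gate's pre-activations. A self-contained numeric check (illustration only):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float x = 0.9f;  // update-gate pre-activation
  const float u = 1.0f / (1.0f + std::exp(-x));
  const float c = 0.8f, h_prev = 0.5f;  // candidate, previous output
  const float origin_out = (1.0f - u) * h_prev + u * c;   // origin_mode
  const float default_out = u * h_prev + (1.0f - u) * c;  // default, gate u
  const float flipped_out = (1.0f - u) * h_prev + u * c;  // default, gate 1-u
  // origin_out and flipped_out are identical; default_out differs.
  std::printf("origin=%f default=%f default(1-u)=%f\n", origin_out,
              default_out, flipped_out);
}
```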