|
|
|
@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
|
|
|
|
|
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
|
|
|
T *gate_value, T *prev_output_value,
|
|
|
|
|
T *output_value, int frame_size,
|
|
|
|
|
ActivationType active_node) {
|
|
|
|
|
ActivationType active_node,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
T r_value_update_gate;
|
|
|
|
|
T r_value_frame_state;
|
|
|
|
|
T r_prev_out = 0;
|
|
|
|
@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
|
|
|
|
|
&r_output, active_node);
|
|
|
|
|
&r_output, active_node, origin_mode);
|
|
|
|
|
|
|
|
|
|
frame_state[i] = r_value_frame_state;
|
|
|
|
|
output_value[i] = r_output;
|
|
|
|
@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
|
|
|
|
|
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
|
|
|
T *gate_value, T *prev_output_value,
|
|
|
|
|
T *output_value, int frame_size,
|
|
|
|
|
ActivationType active_node) {
|
|
|
|
|
ActivationType active_node,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
|
|
|
|
|
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
|
|
|
|
@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
|
|
|
|
|
&r_output, active_node);
|
|
|
|
|
&r_output, active_node, origin_mode);
|
|
|
|
|
|
|
|
|
|
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
|
|
|
|
|
r_value_frame_state);
|
|
|
|
@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
|
|
|
if (rest > 0) {
|
|
|
|
|
i = n - block;
|
|
|
|
|
op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
|
|
|
|
|
&r_prev_out_last, &r_output, active_node);
|
|
|
|
|
&r_prev_out_last, &r_output, active_node, origin_mode);
|
|
|
|
|
|
|
|
|
|
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
|
|
|
|
|
r_value_frame_state_last);
|
|
|
|
@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
|
|
|
|
|
template <class OpFinalOutput, typename T>
|
|
|
|
|
inline void forward_final_output(OpFinalOutput op_final_output,
|
|
|
|
|
GRUMetaValue<T> value, int frame_size,
|
|
|
|
|
int batch_size, ActivationType active_node) {
|
|
|
|
|
int batch_size, ActivationType active_node,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
for (int b = 0; b < batch_size; b++) {
|
|
|
|
|
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
|
|
|
|
|
(sizeof(T) == 4)) {
|
|
|
|
|
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
|
|
|
|
|
value.prev_out_value, value.output_value,
|
|
|
|
|
frame_size, active_node);
|
|
|
|
|
frame_size, active_node, origin_mode);
|
|
|
|
|
} else {
|
|
|
|
|
hl_naive_gru_forward_final_output(
|
|
|
|
|
op_final_output, value.gate_value, value.prev_out_value,
|
|
|
|
|
value.output_value, frame_size, active_node);
|
|
|
|
|
value.output_value, frame_size, active_node, origin_mode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
value.gate_value += frame_size * 3;
|
|
|
|
|