@ -85,26 +85,59 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
T * prev_output_value , int frame_size ,
T * prev_output_value , int frame_size ,
ActivationType active_gate ) {
ActivationType active_gate ) {
# ifdef __AVX__
# ifdef __AVX__
__m256 r_value_update_gate ;
__m256 r_value_update_gate , r_value_update_gate_last = _mm256_set1_ps ( 0.0f ) ;
__m256 r_value_reset_gate ;
__m256 r_value_reset_gate , r_value_reset_gate_last = _mm256_set1_ps ( 0.0f ) ;
__m256 r_value_reset_output ;
__m256 r_value_reset_output ;
__m256 r_prev_out = _mm256_set1_ps ( 0.0f ) ;
__m256 r_prev_out = _mm256_set1_ps ( 0.0f ) ,
__m256 * update_gate = reinterpret_cast < __m256 * > ( gate_value ) ;
r_prev_out_last = _mm256_set1_ps ( 0.0f ) ;
__m256 * reset_gate = reinterpret_cast < __m256 * > ( gate_value + frame_size ) ;
T * update_gate = gate_value ;
T * reset_gate = gate_value + frame_size ;
int block = 8 ;
const int n = frame_size ;
const int rest = n % block ;
const int end = n - rest ;
int i = 0 ;
if ( rest > 0 ) {
i = n - block ;
r_value_update_gate_last =
_mm256_loadu_ps ( ( const float * ) ( update_gate + i ) ) ;
r_value_reset_gate_last = _mm256_loadu_ps ( ( const float * ) ( reset_gate + i ) ) ;
if ( prev_output_value ) {
r_prev_out_last = _mm256_loadu_ps ( ( const float * ) ( prev_output_value + i ) ) ;
}
}
for ( int i = 0 ; i < frame_size / 8 ; i + + ) {
for ( i = 0 ; i < end; i + = block ) {
r_value_update_gate = update_gate [ i ] ;
r_value_update_gate = _mm256_loadu_ps ( ( const float * ) ( update_gate + i ) ) ;
r_value_reset_gate = reset_gate [ i ] ;
r_value_reset_gate = _mm256_loadu_ps ( ( const float * ) ( reset_gate + i ) ) ;
if ( prev_output_value ) {
if ( prev_output_value ) {
r_prev_out = ( reinterpret_cast < __m256 * > ( prev_output_value ) ) [ i ] ;
r_prev_out = _mm256_loadu_ps ( ( const float * ) ( prev_output_value + i ) ) ;
}
}
op_reset_output ( & r_value_update_gate , & r_value_reset_gate , & r_prev_out ,
op_reset_output ( & r_value_update_gate , & r_value_reset_gate , & r_prev_out ,
& r_value_reset_output , active_gate ) ;
& r_value_reset_output , active_gate ) ;
update_gate [ i ] = r_value_update_gate ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( update_gate + i ) ,
reset_gate [ i ] = r_value_reset_gate ;
r_value_update_gate ) ;
( reinterpret_cast < __m256 * > ( reset_output_value ) ) [ i ] = r_value_reset_output ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( reset_gate + i ) ,
r_value_reset_gate ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( reset_output_value + i ) ,
r_value_reset_output ) ;
}
if ( rest > 0 ) {
i = n - block ;
op_reset_output ( & r_value_update_gate_last , & r_value_reset_gate_last ,
& r_prev_out_last , & r_value_reset_output , active_gate ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( update_gate + i ) ,
r_value_update_gate_last ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( reset_gate + i ) ,
r_value_reset_gate_last ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( reset_output_value + i ) ,
r_value_reset_output ) ;
}
}
# endif
# endif
}
}
@ -115,26 +148,55 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T * output_value , int frame_size ,
T * output_value , int frame_size ,
ActivationType active_node ) {
ActivationType active_node ) {
# ifdef __AVX__
# ifdef __AVX__
__m256 r_value_update_gate ;
__m256 r_value_update_gate , r_value_update_gate_last = _mm256_set1_ps ( 0.0f ) ;
__m256 r_value_frame_state ;
__m256 r_value_frame_state , r_value_frame_state_last = _mm256_set1_ps ( 0.0f ) ;
__m256 r_prev_out = _mm256_set1_ps ( 0.0f ) ;
__m256 r_prev_out = _mm256_set1_ps ( 0.0f ) ,
r_prev_out_last = _mm256_set1_ps ( 0.0f ) ;
__m256 r_output ;
__m256 r_output ;
__m256 * update_gate = reinterpret_cast < __m256 * > ( gate_value ) ;
T * update_gate = gate_value ;
__m256 * frame_state = reinterpret_cast < __m256 * > ( gate_value + frame_size * 2 ) ;
T * frame_state = gate_value + frame_size * 2 ;
int block = 8 ;
const int n = frame_size ;
const int rest = n % block ;
const int end = n - rest ;
int i = 0 ;
if ( rest > 0 ) {
i = n - block ;
r_value_update_gate_last =
_mm256_loadu_ps ( ( const float * ) ( update_gate + i ) ) ;
r_value_frame_state_last =
_mm256_loadu_ps ( ( const float * ) ( frame_state + i ) ) ;
if ( prev_output_value ) {
r_prev_out_last = _mm256_loadu_ps ( ( const float * ) ( prev_output_value + i ) ) ;
}
}
for ( int i = 0 ; i < frame_size / 8 ; i + + ) {
for ( i = 0 ; i < end; i + = block ) {
r_value_update_gate = update_gate [ i ] ;
r_value_update_gate = _mm256_loadu_ps ( ( const float * ) ( update_gate + i ) ) ;
r_value_frame_state = frame_state [ i ] ;
r_value_frame_state = _mm256_loadu_ps ( ( const float * ) ( frame_state + i ) ) ;
if ( prev_output_value ) {
if ( prev_output_value ) {
r_prev_out = ( reinterpret_cast < __m256 * > ( prev_output_value ) ) [ i ] ;
r_prev_out = _mm256_loadu_ps ( ( const float * ) ( prev_output_value + i ) ) ;
}
}
op_final_output ( & r_value_update_gate , & r_value_frame_state , & r_prev_out ,
op_final_output ( & r_value_update_gate , & r_value_frame_state , & r_prev_out ,
& r_output , active_node ) ;
& r_output , active_node ) ;
frame_state [ i ] = r_value_frame_state ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( frame_state + i ) ,
( reinterpret_cast < __m256 * > ( output_value ) ) [ i ] = r_output ;
r_value_frame_state ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( output_value + i ) , r_output ) ;
}
if ( rest > 0 ) {
i = n - block ;
op_final_output ( & r_value_update_gate_last , & r_value_frame_state_last ,
& r_prev_out_last , & r_output , active_node ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( frame_state + i ) ,
r_value_frame_state_last ) ;
_mm256_storeu_ps ( reinterpret_cast < float * > ( output_value + i ) , r_output ) ;
}
}
# endif
# endif
}
}
@ -143,7 +205,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
GRUMetaValue < T > value , int frame_size ,
GRUMetaValue < T > value , int frame_size ,
int batch_size , ActivationType active_gate ) {
int batch_size , ActivationType active_gate ) {
for ( int b = 0 ; b < batch_size ; b + + ) {
for ( int b = 0 ; b < batch_size ; b + + ) {
if ( OpResetOutput : : avx & & ! ( frame_size & ( 8 - 1 ) ) & & ( sizeof ( T ) = = 4 ) ) {
if ( OpResetOutput : : avx & & ( frame_size > static_cast < int > ( 8 - 1 ) ) & &
( sizeof ( T ) = = 4 ) ) {
hl_avx_gru_forward_reset_output (
hl_avx_gru_forward_reset_output (
op_reset_output , value . gate_value , value . reset_output_value ,
op_reset_output , value . gate_value , value . reset_output_value ,
value . prev_out_value , frame_size , active_gate ) ;
value . prev_out_value , frame_size , active_gate ) ;
@ -166,7 +229,8 @@ inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue < T > value , int frame_size ,
GRUMetaValue < T > value , int frame_size ,
int batch_size , ActivationType active_node ) {
int batch_size , ActivationType active_node ) {
for ( int b = 0 ; b < batch_size ; b + + ) {
for ( int b = 0 ; b < batch_size ; b + + ) {
if ( OpFinalOutput : : avx & & ! ( frame_size & ( 8 - 1 ) ) & & ( sizeof ( T ) = = 4 ) ) {
if ( OpFinalOutput : : avx & & ( frame_size > static_cast < int > ( 8 - 1 ) ) & &
( sizeof ( T ) = = 4 ) ) {
hl_avx_gru_forward_final_output ( op_final_output , value . gate_value ,
hl_avx_gru_forward_final_output ( op_final_output , value . gate_value ,
value . prev_out_value , value . output_value ,
value . prev_out_value , value . output_value ,
frame_size , active_node ) ;
frame_size , active_node ) ;