@ -232,40 +232,28 @@ use lstm_x_t as input and compute as standard LSTM.
template < typename T >
inline void bias_relu ( const int n , const T * x , const T * bias , T * y ) {
if ( bias ) {
for ( int i = 0 ; i < n ; + + i ) {
y [ i ] = x [ i ] + bias [ 0 ] ;
}
math : : vec_relu < T > ( n , y , y ) ;
math : : vec_add_bias < T , platform : : jit : : avx > ( n , * bias , x , y ) ;
math : : vec_relu < T , platform : : jit : : avx > ( n , y , y ) ;
} else {
math : : vec_relu < T > ( n , x , y ) ;
math : : vec_relu < T , platform : : jit : : avx > ( n , x , y ) ;
}
}
template < typename DeviceContext , typename T >
inline void vec_softmax ( const math : : BlasT < DeviceContext , T > & blas , const int n ,
const T * x , T * y ) {
template < typename T >
inline void vec_softmax ( const int n , const T * x , T * y ) {
T scalar = x [ 0 ] ;
// max
for ( int i = 1 ; i < n ; + + i ) {
scalar = scalar < x [ i ] ? x [ i ] : scalar ;
}
// sub
for ( int i = 0 ; i < n ; + + i ) {
y [ i ] = x [ i ] - scalar ;
}
// exp
blas . VEXP ( n , y , y ) ;
math : : vec_add_bias < T , platform : : jit : : avx > ( n , - scalar , x , y ) ; // sub
math : : vec_exp < T > ( n , y , y ) ; // exp
// sum
scalar = T ( 0 ) ;
for ( int i = 0 ; i < n ; + + i ) {
scalar + = y [ i ] ;
}
// scale
blas . SCAL ( n , static_cast < T > ( 1 ) / scalar , y ) ;
math : : vec_scal < T > ( n , static_cast < T > ( 1 ) / scalar , y ) ; // scale
}
template < typename T >
@ -311,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ ( c0 - > dims ( ) [ 0 ] , N , " C0 dims should be %d x %d. " , N , D ) ;
fc_out - > Resize ( { max_seq_len , 1 } ) ;
math : : VecActivations < T > act_functor ;
std : : function < void ( const int , const T * , T * ) > act_gate , act_cell , act_cand ;
act_gate = act_functor ( ctx . Attr < std : : string > ( " gate_activation " ) ) ;
act_cell = act_functor ( ctx . Attr < std : : string > ( " cell_activation " ) ) ;
act_cand = act_functor ( ctx . Attr < std : : string > ( " candidate_activation " ) ) ;
auto & act_gate_str = ctx . Attr < std : : string > ( " gate_activation " ) ;
auto & act_cell_str = ctx . Attr < std : : string > ( " cell_activation " ) ;
auto & act_cand_str = ctx . Attr < std : : string > ( " candidate_activation " ) ;
if ( platform : : jit : : MayIUse ( platform : : jit : : avx ) ) {
math : : VecActivations < T , platform : : jit : : avx > act_functor ;
act_gate = act_functor ( act_gate_str ) ;
act_cell = act_functor ( act_cell_str ) ;
act_cand = act_functor ( act_cand_str ) ;
} else {
math : : VecActivations < T , platform : : jit : : isa_any > act_functor ;
act_gate = act_functor ( act_gate_str ) ;
act_cell = act_functor ( act_cell_str ) ;
act_cand = act_functor ( act_cand_str ) ;
}
const T * x_data = x - > data < T > ( ) ;
const T * h0_data = h0 ? h0 - > data < T > ( ) : NULL ;
@ -363,7 +361,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
fc_out_data ) ;
}
// 1d. softmax
vec_softmax < DeviceContext, T> ( blas , seq_len , fc_out_data , fc_out_data ) ;
vec_softmax < T> ( seq_len , fc_out_data , fc_out_data ) ;
// mul x(seq_len*M) and sum pool
math : : FCCompute < DeviceContext , T > ( blas , 1 , M , seq_len , fc_out_data ,
cur_x_data , lstm_x_data ) ;