@ -31,7 +31,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen : : DenseIndex >
using EigenMatrix = framework : : EigenMatrix < T , MajorType , IndexType > ;
template < typename T >
template < typename T , bool is_test >
class MaxSeqPoolFunctor {
public :
void operator ( ) ( const platform : : CPUDeviceContext & context ,
@ -70,7 +70,41 @@ class MaxSeqPoolFunctor {
}
}
} ;
// Instantisation of Max Sequence Pooling for test phase eg. no need to fill
// index buffer
template < typename T >
class MaxSeqPoolFunctor < T , true > {
public :
void operator ( ) ( const platform : : CPUDeviceContext & context ,
const framework : : LoDTensor & input , framework : : Tensor * output ,
framework : : Tensor * index ) {
auto in_dims = input . dims ( ) ;
auto out_dims = output - > dims ( ) ;
PADDLE_ENFORCE_GT ( in_dims . size ( ) , 1 ) ;
PADDLE_ENFORCE_GT ( out_dims . size ( ) , 1 ) ;
for ( int64_t i = 1 ; i < in_dims . size ( ) ; + + i ) {
PADDLE_ENFORCE_EQ ( in_dims [ i ] , out_dims [ i ] ) ;
}
auto starts = input . lod ( ) [ 0 ] ;
const T * in_data = input . data < T > ( ) ;
T * out_data = output - > data < T > ( ) ;
int64_t num_seq = out_dims [ 0 ] ;
int64_t dim = output - > numel ( ) / num_seq ;
for ( int64_t i = 0 ; i < num_seq ; + + i ) {
std : : memcpy ( & out_data [ i * dim ] , & in_data [ starts [ i ] * dim ] ,
dim * sizeof ( T ) ) ;
for ( size_t j = starts [ i ] + 1 ; j < starts [ i + 1 ] ; + + j ) {
for ( int64_t k = 0 ; k < dim ; + + k ) {
if ( in_data [ j * dim + k ] > out_data [ i * dim + k ] ) {
out_data [ i * dim + k ] = in_data [ j * dim + k ] ;
}
}
}
}
}
} ;
template < typename T >
class MaxSeqPoolGradFunctor {
public :
@ -188,11 +222,16 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
/* max pool has index output */
void operator ( ) ( const platform : : CPUDeviceContext & context ,
const std : : string pooltype , const framework : : LoDTensor & input ,
framework : : Tensor * output ,
framework : : Tensor * output , bool is_test ,
framework : : Tensor * index = nullptr ) {
if ( pooltype = = " MAX " ) {
math : : MaxSeqPoolFunctor < T > max_pool ;
max_pool ( context , input , output , index ) ;
if ( is_test ) {
math : : MaxSeqPoolFunctor < T , true > max_pool ;
max_pool ( context , input , output , index ) ;
} else {
math : : MaxSeqPoolFunctor < T , false > max_pool ;
max_pool ( context , input , output , index ) ;
}
return ;
}
if ( pooltype = = " LAST " ) {
@ -200,6 +239,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
last_pool ( context , input , output ) ;
return ;
}
if ( pooltype = = " FIRST " ) {
math : : FirstSeqPoolFunctor < T > first_pool ;
first_pool ( context , input , output ) ;