@@ -87,7 +87,7 @@ template <typename DeviceContext, typename T>
 class ContextProjectFunctor {
  public:
   void operator()(const DeviceContext& context, const LoDTensor& in,
-                  const Tensor& padding_data, bool padding_trainable,
+                  const Tensor* padding_data, bool padding_trainable,
                   const int context_start, const int context_length,
                   const int context_stride, const int up_pad,
                   const int down_pad, Tensor* col) {
@@ -132,6 +132,7 @@ class ContextProjectFunctor {
       }
     }
     if (padding_trainable) {
+      PADDLE_ENFORCE_NOT_NULL(padding_data);
       for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
         Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
                                   static_cast<int>(lod_level_0[i + 1]));
@@ -150,7 +151,7 @@ class ContextProjectFunctor {
                 k + context_length < up_pad ? context_length : up_pad - k;
             Tensor out_t_sub = out_t.Slice(k * context_length,
                                            k * context_length + padding_size);
-            Tensor w_sub = padding_data.Slice(k, k + padding_size);
+            Tensor w_sub = padding_data->Slice(k, k + padding_size);
             framework::TensorCopy(w_sub, context.GetPlace(), context,
                                   &out_t_sub);
           }
@@ -180,7 +181,7 @@ class ContextProjectFunctor {
             Tensor out_t_sub = out_t.Slice(
                 (down_pad_begin_row + t) * context_length - padding_size,
                 (down_pad_begin_row + t) * context_length);
-            Tensor w_sub = padding_data.Slice(
+            Tensor w_sub = padding_data->Slice(
                 up_pad + padding_idx, up_pad + padding_idx + padding_size);
             framework::TensorCopy(w_sub, context.GetPlace(), context,
                                   &out_t_sub);
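
For reference, a minimal sketch of how a caller adapts to the pointer-based interface of ContextProjectFunctor::operator(). The variable names (dev_ctx, in, padding_data, col, and the context/padding parameters) are illustrative placeholders, not taken verbatim from the sequence_conv kernel.

    // Hypothetical call site; all tensors and scalar parameters are assumed
    // to be set up by the surrounding kernel.
    math::ContextProjectFunctor<DeviceContext, T> seq_project_functor;

    // padding_data is now passed as a pointer: nullptr is acceptable when
    // padding_trainable is false, and PADDLE_ENFORCE_NOT_NULL guards the
    // trainable-padding branch inside the functor.
    const Tensor* padding_ptr = padding_trainable ? &padding_data : nullptr;
    seq_project_functor(dev_ctx, in, padding_ptr, padding_trainable,
                        context_start, context_length, context_stride,
                        up_pad, down_pad, &col);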