@ -225,7 +225,52 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
vbroadcast ( src , dst , h , out_width ) ;
}
} else {
# if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
! defined ( __OSX__ ) & & ! defined ( PADDLE_WITH_CUDA )
auto * ids = context . Input < LoDTensor > ( " Ids " ) ;
auto * d_output = context . Input < LoDTensor > ( framework : : GradVarName ( " Out " ) ) ;
auto * d_table = context . Output < LoDTensor > ( framework : : GradVarName ( " W " ) ) ;
d_table - > Resize ( table_dim ) ;
auto * d_table_data = d_table - > mutable_data < T > ( context . GetPlace ( ) ) ;
memset ( d_table_data , 0 , d_table - > numel ( ) * sizeof ( T ) ) ;
const auto & ids_lod = ids - > lod ( ) ;
PADDLE_ENFORCE ( ids_lod . size ( ) , 1UL ,
" The LoD level of Input(Ids) must be 1 " ) ;
const std : : vector < uint64_t > offset = ids_lod [ 0 ] ;
auto len = ids - > numel ( ) ;
int idx_width = len / offset . back ( ) ;
Tensor csr_vals_t , csr_colmuns_t , csr_row_idx_t ;
csr_vals_t . Resize ( { len } ) ;
csr_colmuns_t . Resize ( { len } ) ;
int64_t batch_size = ids_lod [ 0 ] . size ( ) - 1 ;
csr_row_idx_t . Resize ( { ( batch_size + 1 ) * idx_width } ) ;
auto csr_vals = csr_vals_t . mutable_data < T > ( context . GetPlace ( ) ) ;
auto csr_colmuns = csr_colmuns_t . mutable_data < int > ( context . GetPlace ( ) ) ;
auto csr_row_idx = csr_row_idx_t . mutable_data < int > ( context . GetPlace ( ) ) ;
prepare_csr_data < T > ( offset , ids - > data < int64_t > ( ) , idx_width , csr_vals ,
csr_colmuns , csr_row_idx ) ;
auto * d_output_data = d_output - > data < T > ( ) ;
const char transa = ' T ' ;
const T alpha = 1.0 ;
const T beta = 0.0 ;
const char matdescra [ ] = { ' G ' , ' L ' , ' N ' , ' C ' } ;
const int m = batch_size * idx_width ;
const int n = table_dim [ 1 ] ;
const int k = table_dim [ 1 ] ;
auto blas = math : : GetBlas < platform : : CPUDeviceContext , T > ( context ) ;
blas . CSRMM ( & transa , & m , & n , & k , & alpha , matdescra , ( const T * ) csr_vals ,
( const int * ) csr_colmuns , ( const int * ) csr_row_idx ,
( const int * ) csr_row_idx + 1 , d_output_data , & n , & beta ,
d_table_data , & n ) ;
# else
LOG ( ERROR ) < < " Dense is not supported in fused_embedding_seq_pool_op now " ;
# endif
}
}
} ;