@@ -33,12 +33,15 @@ using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
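+// Sentinel for the padding_idx attribute: kNoPadding (-1) means no id is
+// treated as padding.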
+constexpr int64_t kNoPadding = -1;
+
 #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__)
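+// prepare_csr_data builds a CSR matrix with one row per (sequence, slot)
+// pair; each entry counts how many times an id occurs, so multiplying the
+// matrix against the embedding table realizes sum pooling.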
 template <typename T>
 void prepare_csr_data(const std::vector<uint64_t> &offset,
                       const int64_t *ids_data, const size_t idx_width,
-                      T *csr_vals, int *csr_colmuns, int *csr_row_idx) {
+                      T *csr_vals, int *csr_colmuns, int *csr_row_idx,
+                      int64_t padding_idx = kNoPadding) {
   int val_idx = 0;
   int row_idx = 0;
   csr_row_idx[0] = 0;
@@ -52,10 +55,12 @@ void prepare_csr_data(const std::vector<uint64_t> &offset,
       // construct a map for creating csr
       for (size_t j = offset[i]; j < offset[i + 1]; ++j) {
-        unsigned int word_idx =
-            static_cast<unsigned int>(ids_data[idx + j * idx_width]);
-        ++ids_map[word_idx];
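+        // ids equal to padding_idx are skipped, so padded positions
+        // contribute nothing to the pooled sum (or to the gradient)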
+        auto ids_value = ids_data[idx + j * idx_width];
+        if (ids_value != padding_idx) {
+          unsigned int word_idx = static_cast<unsigned int>(ids_value);
+          ++ids_map[word_idx];
+        }
       }
       VLOG(4) << "====sequence %d====" << i;
       for (std::map<int, int>::const_iterator it = ids_map.begin();
@@ -124,7 +129,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
         FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
     const auto &ids_lod = ids_t->lod();
     // in run time, the LoD of ids must be 1
-    PADDLE_ENFORCE(ids_lod.size(), 1UL,
-                   "The LoD level of Input(Ids) must be 1");
+    PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
+                      "The LoD level of Input(Ids) must be 1");
     int64_t batch_size = ids_lod[0].size() - 1;
     // in run time, the shape from Ids -> output
@@ -133,7 +138,8 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
     if (combiner_type == "sum") {
 #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__)
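+      // read the padding_idx attribute; kNoPadding (-1) disables the check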
+      int64_t padding_idx = context.Attr<int64_t>("padding_idx");
       auto output = output_t->mutable_data<T>(context.GetPlace());
       int64_t table_height = table_var->dims()[0];
       int64_t table_width = table_var->dims()[1];
@@ -151,7 +157,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
       auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
       auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
       prepare_csr_data<T>(offset, ids_t->data<int64_t>(), idx_width, csr_vals,
-                          csr_colmuns, csr_row_idx);
+                          csr_colmuns, csr_row_idx, padding_idx);
 
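+      // padding entries never enter the CSR data, so the csrmm below yields
+      // zero rows for positions that contain only padding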
       const char transa = 'N';
       const T alpha = 1.0;
@@ -226,17 +232,18 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       }
     } else {
 #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__)
       auto *ids = context.Input<LoDTensor>("Ids");
       auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
       auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
+      int64_t padding_idx = context.Attr<int64_t>("padding_idx");
 
       d_table->Resize(table_dim);
       auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
       memset(d_table_data, 0, d_table->numel() * sizeof(T));
 
       const auto &ids_lod = ids->lod();
-      PADDLE_ENFORCE(ids_lod.size(), 1UL,
-                     "The LoD level of Input(Ids) must be 1");
+      PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
+                        "The LoD level of Input(Ids) must be 1");
       const std::vector<uint64_t> offset = ids_lod[0];
       auto len = ids->numel();
@@ -251,23 +258,21 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
       auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
       prepare_csr_data<T>(offset, ids->data<int64_t>(), idx_width, csr_vals,
-                          csr_colmuns, csr_row_idx);
+                          csr_colmuns, csr_row_idx, padding_idx);
 
       auto *d_output_data = d_output->data<T>();
-      const char transa = 'T';
-      const T alpha = 1.0;
-      const T beta = 0.0;
-      const char matdescra[] = {'G', 'L', 'N', 'C'};
-
-      const int m = batch_size * idx_width;
-      const int n = table_dim[1];
-      const int k = table_dim[1];
-
       auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-      blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals,
-                 (const int *)csr_colmuns, (const int *)csr_row_idx,
-                 (const int *)csr_row_idx + 1, d_output_data, &n, &beta,
-                 d_table_data, &n);
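+      // replace the transposed csrmm with an explicit scatter-add: each
+      // nonzero (row i, column word_idx, value val) accumulates
+      // val * d_output[i] into the gradient row of word_idx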
+      int width = static_cast<int>(table_dim[1]);
+      int num_seq = batch_size * idx_width;
+      LOG(INFO) << "num seq = " << num_seq << " width = " << width;
+      for (int i = 0; i < num_seq; ++i) {
+        for (int j = csr_row_idx[i]; j < csr_row_idx[i + 1]; ++j) {
+          unsigned int word_idx = csr_colmuns[j];
+          T val = csr_vals[j];
+          blas.AXPY(width, val, d_output_data + i * width,
+                    d_table_data + word_idx * width);
+        }
+      }
 #else
       LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
 #endif