@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/math/blas.h"

 namespace paddle {
 namespace operators {
 namespace math {
@@ -65,36 +66,42 @@ void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
                  .broadcast(one_by_class));
 }

-template <typename DeviceContext, typename T>
-class SoftmaxFunctor<DeviceContext, T, true> {
+template <typename DeviceContext>
+class SoftmaxFunctor<DeviceContext, float, true> {
   void operator()(const DeviceContext& context, const framework::Tensor* X,
                   framework::Tensor* Y) {
-    auto logits = EigenMatrix<T>::From(*X);
-    auto softmax = EigenMatrix<T>::From(*Y);
+    auto in_dims = X->dims();
+    auto out_dims = Y->dims();
+    const float* in_data = X->data<float>();
+    float* out_data = Y->data<float>();
     const int kBatchDim = 0;
     const int kClassDim = 1;
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-    auto shifted_logits = (logits -
-                           logits.maximum(along_class)
-                               .eval()
-                               .reshape(batch_by_one)
-                               .broadcast(one_by_class));
-    softmax.device(*context.eigen_device()) = shifted_logits.exp();
-    softmax.device(*context.eigen_device()) = (softmax *
-                                               softmax.sum(along_class)
-                                                   .inverse()
-                                                   .eval()
-                                                   .reshape(batch_by_one)
-                                                   .broadcast(one_by_class));
+    // 2D data. Batch x C
+    const int batch_size = in_dims[kBatchDim];
+    const int num_classes = in_dims[kClassDim];
+    std::vector<float> entities(batch_size);
+    auto blas = math::GetBlas<DeviceContext, float>(context);
+    for (int n = 0; n < batch_size; ++n) {
+      entities[n] = in_data[n * num_classes];
+      for (int c = 1; c < num_classes; ++c) {
+        entities[n] = in_data[n * num_classes + c] > entities[n]
+                          ? in_data[n * num_classes + c]
+                          : entities[n];
+      }
+      for (int c = 0; c < num_classes; ++c) {
+        out_data[n * num_classes + c] =
+            in_data[n * num_classes + c] - entities[n];
+      }
+    }
+    blas.VEXP(num_classes * batch_size, out_data, out_data);
+    for (int n = 0; n < batch_size; ++n) {
+      entities[n] = out_data[n * num_classes];
+      for (int c = 1; c < num_classes; ++c) {
+        entities[n] += out_data[n * num_classes + c];
+      }
+      blas.SCAL(num_classes, 1.0f / entities[n], &out_data[n * num_classes]);
+    }
   }
 };
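
For reference, below is a minimal standalone sketch of the row-wise softmax that the new SoftmaxFunctor<DeviceContext, float, true> specialization computes. It is an illustrative re-implementation using plain loops and the standard library only, not Paddle code; the function name naive_softmax_2d is made up for this example. It mirrors the three passes in the diff: take the per-row maximum and subtract it (the first loop over entities), exponentiate the shifted values (the role of blas.VEXP), then divide each row by its sum (the role of blas.SCAL with 1.0f / entities[n]).

#include <algorithm>
#include <cmath>

// Numerically stable softmax over each row of a row-major
// batch_size x num_classes matrix, mirroring the loop structure
// of the specialization above.
void naive_softmax_2d(const float* in, float* out, int batch_size,
                      int num_classes) {
  for (int n = 0; n < batch_size; ++n) {
    const float* row_in = in + n * num_classes;
    float* row_out = out + n * num_classes;
    // Pass 1: per-row maximum, so exp() cannot overflow.
    const float row_max = *std::max_element(row_in, row_in + num_classes);
    // Pass 2: shift and exponentiate, accumulating the row sum.
    float sum = 0.0f;
    for (int c = 0; c < num_classes; ++c) {
      row_out[c] = std::exp(row_in[c] - row_max);
      sum += row_out[c];
    }
    // Pass 3: normalize the row so it sums to 1.
    for (int c = 0; c < num_classes; ++c) {
      row_out[c] /= sum;
    }
  }
}

The diff keeps the max/exp/normalize passes separate, presumably so that the exponential over the whole batch can be handed to blas.VEXP in a single call instead of being evaluated element by element through the Eigen expression it replaces.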