@ -61,6 +61,19 @@ inline void get_mid_dims(const framework::DDim& x_dims,
}
}
inline void trim_trailing_singular_dims ( framework : : DDim & dims ) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims . size ( ) ;
for ( ; actual_dims_size ! = 0 ; - - actual_dims_size ) {
if ( dims [ actual_dims_size - 1 ] ! = 1 ) break ;
}
if ( actual_dims_size ! = dims . size ( ) ) {
auto actual_dims = framework : : vectorize ( dims ) ;
actual_dims . resize ( actual_dims_size ) ;
dims = framework : : make_ddim ( actual_dims ) ;
}
}
template < typename T , typename DeviceContext >
class RowwiseTransformIterator ;
template < typename T , typename DeviceContext >
@ -263,44 +276,6 @@ class TransformFunctor {
} \
}
template < class functor , typename DeviceContext , typename T >
void ElementwiseCompute ( const framework : : ExecutionContext & ctx ) {
using Tensor = framework : : Tensor ;
auto * x = ctx . Input < Tensor > ( " X " ) ;
auto * y = ctx . Input < Tensor > ( " Y " ) ;
auto * z = ctx . Output < Tensor > ( " Out " ) ;
z - > mutable_data < T > ( ctx . GetPlace ( ) ) ;
auto x_dims = x - > dims ( ) ;
auto y_dims = y - > dims ( ) ;
PADDLE_ENFORCE_GE ( x_dims . size ( ) , y_dims . size ( ) ,
" Rank of first input must >= rank of second input. " ) ;
if ( x_dims = = y_dims ) {
functor f ;
f . template Run < DeviceContext , T > ( x , y , z , ctx ) ;
return ;
}
int axis = ctx . Attr < int > ( " axis " ) ;
axis = ( axis = = - 1 ? x_dims . size ( ) - y_dims . size ( ) : axis ) ;
PADDLE_ENFORCE ( axis > = 0 & & axis < x_dims . size ( ) ,
" Axis should be in range [0, x_dims) " ) ;
int pre , n , post ;
get_mid_dims ( x_dims , y_dims , axis , pre , n , post ) ;
if ( post = = 1 ) {
functor f ;
f . template RunBroadCast < DeviceContext , T > ( x , y , z , ctx , pre , n ) ;
return ;
} else {
functor f ;
f . template RunBroadCast2 < DeviceContext , T > ( x , y , z , ctx , pre , n , post ) ;
return ;
}
}
# define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR ( Add , EIGEN_ADD ) ;
@ -516,14 +491,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
auto x_dim = x . dims ( ) ;
auto y_dim = y . dims ( ) ;
if ( y_dim . size ( ) = = 1 & & y_dim [ 0 ] = = 1 ) {
// y is a scalar
auto extended_dims = framework : : vectorize ( x_dim ) ;
extended_dims . push_back ( 1 ) ;
x_dim = framework : : make_ddim ( extended_dims ) ;
}
axis = ( axis = = - 1 ? x_dim . size ( ) - y_dim . size ( ) : axis ) ;
trim_trailing_singular_dims ( y_dim ) ;
axis = ( y_dim . size ( ) = = 0 ) ? x_dim . size ( ) : axis ;
int pre , n , post ;
get_mid_dims ( x_dim , y_dim , axis , pre , n , post ) ;
if ( post = = 1 ) {
@ -591,14 +562,9 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
return ;
}
if ( y_dims . size ( ) = = 1 & & y_dims [ 0 ] = = 1 ) {
// y is a scalar
auto extended_dims = framework : : vectorize ( x_dims ) ;
extended_dims . push_back ( 1 ) ;
x_dims = framework : : make_ddim ( extended_dims ) ;
}
axis = ( axis = = - 1 ? x_dims . size ( ) - y_dims . size ( ) : axis ) ;
trim_trailing_singular_dims ( y_dims ) ;
axis = ( y_dims . size ( ) = = 0 ) ? x_dims . size ( ) : axis ;
int pre , n , post ;
get_mid_dims ( x_dims , y_dims , axis , pre , n , post ) ;
@ -633,16 +599,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
return ;
}
if ( y_dims . size ( ) = = 1 & & y_dims [ 0 ] = = 1 ) {
// y is a scalar
auto extended_dims = framework : : vectorize ( x_dims ) ;
extended_dims . push_back ( 1 ) ;
x_dims = framework : : make_ddim ( extended_dims ) ;
}
axis = ( axis = = - 1 ? x_dims . size ( ) - y_dims . size ( ) : axis ) ;
PADDLE_ENFORCE ( axis > = 0 & & axis < x_dims . size ( ) ,
" Axis should be in range [0, x_dims) " ) ;
trim_trailing_singular_dims ( y_dims ) ;
axis = ( y_dims . size ( ) = = 0 ) ? x_dims . size ( ) : axis ;
int pre , n , post ;
get_mid_dims ( x_dims , y_dims , axis , pre , n , post ) ;