@ -27,17 +27,21 @@ using Tensor = framework::Tensor;
inline void FCOutputSize ( const framework : : DDim & in_dims ,
const framework : : DDim & w_dims ,
std : : vector < int64_t > & out_dims , // NOLINT
int in_num_col_dims ) {
int in_num_col_dims , bool padding_weights ) {
auto in_mat_dims = framework : : flatten_to_2d ( in_dims , in_num_col_dims ) ;
PADDLE_ENFORCE_EQ (
in_mat_dims [ 1 ] , w_dims [ 0 ] ,
" Fully Connected input and weigth size do not match. %s, %s " ) ;
auto w_dims0 = padding_weights ? w_dims [ 0 ] - 4 : w_dims [ 0 ] ;
auto w_dims1 = padding_weights ? w_dims [ 1 ] - 4 : w_dims [ 1 ] ;
PADDLE_ENFORCE_EQ ( in_mat_dims [ 1 ] , w_dims0 ,
platform : : errors : : InvalidArgument (
" Fully Connected input and weigth size do not match. "
" input width: %d,weight height: %d " ,
in_mat_dims [ 1 ] , w_dims0 ) ) ;
out_dims . reserve ( static_cast < size_t > ( in_num_col_dims + 1 ) ) ;
for ( int i = 0 ; i < in_num_col_dims ; + + i ) {
out_dims . push_back ( in_dims [ i ] ) ;
}
out_dims . push_back ( w_dims [ 1] ) ;
out_dims . push_back ( w_dims 1) ;
}
template < typename DeviceContext , typename T >
@ -53,14 +57,18 @@ class FCOpKernel : public framework::OpKernel<T> {
( ctx . Attr < std : : string > ( " activation_type " ) = = " relu " ) ? true : false ;
auto w_dims = w - > dims ( ) ;
bool padding_weights = ctx . Attr < bool > ( " padding_weights " ) ;
std : : vector < int64_t > output_dims ;
FCOutputSize ( input - > dims ( ) , w_dims , output_dims , in_num_col_dims ) ;
FCOutputSize ( input - > dims ( ) , w_dims , output_dims , in_num_col_dims ,
padding_weights ) ;
output - > Resize ( framework : : make_ddim ( output_dims ) ) ;
output - > set_lod ( input - > lod ( ) ) ;
auto out_dims = output - > dims ( ) ;
int M = framework : : product ( out_dims ) / w_dims [ 1 ] ;
auto w_dims0 = padding_weights ? w_dims [ 0 ] - 4 : w_dims [ 0 ] ;
auto w_dims1 = padding_weights ? w_dims [ 1 ] - 4 : w_dims [ 1 ] ;
int M = framework : : product ( out_dims ) / w_dims1 ;
const T * input_data = input - > data < T > ( ) ;
const T * w_data = w - > data < T > ( ) ;
@ -68,8 +76,8 @@ class FCOpKernel : public framework::OpKernel<T> {
auto & dev_ctx = ctx . template device_context < DeviceContext > ( ) ;
math : : FCFunctor < DeviceContext , T > fc ;
fc ( dev_ctx , M , w_dims [ 1] , w_dims [ 0] , input_data , w_data , output_data ,
bias ? bias - > data < T > ( ) : NULL , with_relu );
fc ( dev_ctx , M , w_dims 1, w_dims 0, input_data , w_data , output_data ,
bias ? bias - > data < T > ( ) : NULL , with_relu , padding_weights );
}
} ;