@ -54,6 +54,25 @@ class FCPrimitiveFactory {
return ;
} // Otherwise, create a new one.
auto in_col_dims = ctx . Attr < int > ( " in_num_col_dims " ) ;
PADDLE_ENFORCE_LE ( in_col_dims , 2 ,
platform : : errors : : Unimplemented (
" DNNL FC doesn't support in_num_col_dims paramter to "
" be higher than "
" 2. " ) ) ;
if ( in_col_dims = = 2 ) {
PADDLE_ENFORCE_EQ (
input - > dims ( ) . size ( ) , 3 ,
platform : : errors : : Unimplemented (
" DNNL FC only supports in_num_col_dims equal to 2 when "
" 3 dim input is provided. " ) ) ;
PADDLE_ENFORCE_EQ (
input - > format ( ) , MKLDNNMemoryFormat : : ncw ,
platform : : errors : : Unimplemented (
" DNNL FC only supports in_num_col_dims equal to 2 when "
" input format is equal to ncw. " ) ) ;
}
// Transform weights to default MKL-DNN format
weights_ = TransposeWeights ( weights ) ;
// Since MKL-DNN has a lot of limitations on what the input/weights/output
@ -121,6 +140,33 @@ class FCPrimitiveFactory {
}
private :
// DNNL always returns 2-dimensional data block as a result of computing
// inner product. Hence the format 'nc' is always set for its output
// primitive. Therefore, function SetOutputFormat is needed to choose
// an appropriate format based on the number of input dimensions and
// format of an input tensor.
void SetOutputFormat ( MKLDNNMemoryFormat in_format , Tensor * out ) {
int dim_num = out - > dims ( ) . size ( ) ;
// In case of 2 dims, we set the only possible format, nc
if ( dim_num = = 2 ) {
out - > set_format ( MKLDNNMemoryFormat : : nc ) ;
// In case of 3 dims, we generate a format that is based on number
// of output dims and the layout of input format (nchw or nhwc).
} else if ( dim_num = = 3 ) {
if ( in_format = = MKLDNNMemoryFormat : : nwc | |
in_format = = MKLDNNMemoryFormat : : nhwc ) {
out - > set_format (
platform : : MKLDNNFormatForSize ( dim_num , MKLDNNMemoryFormat : : nhwc ) ) ;
} else {
out - > set_format (
platform : : MKLDNNFormatForSize ( dim_num , MKLDNNMemoryFormat : : nchw ) ) ;
}
// In any other case we overwrite the output format with the input one.
} else {
out - > set_format ( in_format ) ;
}
}
void UpdateDataPointers ( const ExecutionContext & ctx , Tensor * out ,
const Tensor * in ) {
input_ - > set_data_handle ( to_void_cast ( in - > data < T_in > ( ) ) ) ;
@ -129,17 +175,7 @@ class FCPrimitiveFactory {
// variable, update its format to what has been determined in first
// call to CreateFcPrimitive method.
if ( out - > format ( ) = = MKLDNNMemoryFormat : : undef ) {
MKLDNNMemoryFormat format ;
auto data_type = input_ - > get_desc ( ) . data . data_type ;
if ( data_type = = mkldnn_f32 )
format = MKLDNNMemoryFormat : : nchw ;
else
format = MKLDNNMemoryFormat : : nhwc ;
MKLDNNMemoryFormat selected = platform : : MKLDNNFormatForSize (
framework : : vectorize < int > ( out - > dims ( ) ) . size ( ) , format ) ;
out - > set_format ( selected ) ;
SetOutputFormat ( in - > format ( ) , out ) ;
}
}
@ -168,8 +204,8 @@ class FCPrimitiveFactory {
const LoDTensor * input , const Tensor * weights , const Tensor * bias ,
LoDTensor * output , const ExecutionContext & ctx ) {
auto input_dims = framework : : vectorize ( input - > dims ( ) ) ;
std : : vector < int64_t > new_input_dims = { input_dims [ 0 ] * input_dims [ 1 ] , 1 ,
input_dims [ 2 ] };
std : : vector < int64_t > new_input_dims = { input_dims [ 0 ] * input_dims [ 1 ] ,
input_dims [ 2 ] , 1 };
auto src_desc = CreateMemDescriptor < T_in > ( new_input_dims , input - > format ( ) ) ;
auto weight_dims = Get3DWeightDimsForDNNL ( weights ) ;
@ -187,7 +223,7 @@ class FCPrimitiveFactory {
std : : vector < int64_t > Get3DWeightDimsForDNNL ( const Tensor * weights ) {
auto paddle_w_dims = framework : : vectorize ( weights - > dims ( ) ) ;
return { paddle_w_dims [ 1 ] , 1 , paddle_w_dims [ 0 ] } ;
return { paddle_w_dims [ 1 ] , paddle_w_dims [ 0 ] , 1 } ;
}
memory : : desc Create3DUserWeightsDesc ( const Tensor * weights ) {
@ -405,18 +441,8 @@ class FCPrimitiveFactory {
T_out * output_data =
output - > mutable_data < T_out > ( ctx . GetPlace ( ) , buffer_size ) ;
memory dst_mem ( dst_desc , engine_ , to_void_cast < T_out > ( output_data ) ) ;
SetOutputFormat ( ctx . Input < LoDTensor > ( " Input " ) - > format ( ) , output ) ;
MKLDNNMemoryFormat format ;
auto data_type = input_ - > get_desc ( ) . data . data_type ;
if ( data_type = = mkldnn_f32 )
format = MKLDNNMemoryFormat : : nchw ;
else
format = MKLDNNMemoryFormat : : nhwc ;
MKLDNNMemoryFormat selected = platform : : MKLDNNFormatForSize (
framework : : vectorize < int > ( output - > dims ( ) ) . size ( ) , format ) ;
output - > set_format ( selected ) ;
return dst_mem ;
}