@ -27,6 +27,7 @@ namespace paddle {
namespace platform {
using user_function = std : : function < std : : shared_ptr < float > ( const float * ) > ;
using memory = mkldnn : : memory ;
class MKLDNNHandler {
public :
@ -196,21 +197,6 @@ class MKLDNNHandler {
return dims2str ( operand_dims ) + suffix ;
}
template < typename T >
static void SetDstMemory (
const framework : : ExecutionContext & ctx , framework : : Tensor * output ,
std : : vector < int > dst_tz , const mkldnn : : engine & engine ,
std : : shared_ptr < mkldnn : : memory : : primitive_desc > & dst_pd , // NOLINT
std : : shared_ptr < mkldnn : : memory > & dst_memory ) { // NOLINT
T * output_data = output - > mutable_data < T > ( ctx . GetPlace ( ) ) ;
auto dst_md = platform : : MKLDNNMemDesc (
{ dst_tz } , paddle : : framework : : ToMKLDNNDataType (
framework : : DataTypeTrait < T > : : DataType ) ,
mkldnn : : memory : : format : : nhwc ) ;
dst_pd . reset ( new mkldnn : : memory : : primitive_desc ( dst_md , engine ) ) ;
dst_memory . reset ( new mkldnn : : memory ( * dst_pd , to_void_cast < T > ( output_data ) ) ) ;
}
static void AppendKey (
std : : string * key , const mkldnn : : memory : : dims & input_dims ,
const mkldnn : : memory : : dims & weights_dims , const std : : vector < int > & strides ,
@ -915,5 +901,26 @@ static void SetDstMemoryHandler(
( * dst_memory_p ) - > set_data_handle ( to_void_cast < T > ( output_data ) ) ;
}
template < typename T >
static void SetDstMemoryQuantized (
const framework : : ExecutionContext & ctx , framework : : Tensor * output ,
std : : vector < int > dst_tz , const mkldnn : : engine & engine ,
std : : shared_ptr < mkldnn : : memory : : primitive_desc > & dst_pd , // NOLINT
std : : shared_ptr < mkldnn : : memory > & dst_memory ) { // NOLINT
T * output_data = output - > mutable_data < T > ( ctx . GetPlace ( ) ) ;
const size_t dst_dims = dst_tz . size ( ) ;
memory : : format dst_fmt ;
PADDLE_ENFORCE ( dst_dims < = 5 ,
" Dst memory for quantization can not have dims > 5 " ) ;
dst_fmt = platform : : MKLDNNFormatForSize ( dst_dims , memory : : format : : nhwc ) ;
auto dst_md = platform : : MKLDNNMemDesc (
{ dst_tz } , paddle : : framework : : ToMKLDNNDataType (
framework : : DataTypeTrait < T > : : DataType ) ,
dst_fmt ) ;
dst_pd . reset ( new mkldnn : : memory : : primitive_desc ( dst_md , engine ) ) ;
dst_memory . reset ( new mkldnn : : memory ( * dst_pd , to_void_cast < T > ( output_data ) ) ) ;
}
} // namespace platform
} // namespace paddle