@@ -288,6 +288,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory_p));
   }
+
   void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
     const bool is_test = ctx.Attr<bool>("is_test");
@@ -325,7 +326,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
     bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
+    float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
     bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+    bool unsigned_output = fuse_relu || fuse_brelu;
     if (fuse_residual_conn) {
       PADDLE_ENFORCE(force_fp32_output != true,
                      "residual fusion does not support force output with fp32");
@@ -340,8 +343,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                    "dilation in convolution is not implemented yet");
     PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
-    PADDLE_ENFORCE(fuse_brelu != true,
-                   "int8 does not support conv/relu6 fusion currently");
     const T* input_data = input->data<T>();
@@ -356,10 +357,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mkldnn::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
-    auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType(
-                                    framework::DataTypeTrait<uint8_t>::DataType)
-                              : paddle::framework::ToMKLDNNDataType(
-                                    framework::DataTypeTrait<int8_t>::DataType);
+    auto dst_dt = unsigned_output
+                      ? paddle::framework::ToMKLDNNDataType(
+                            framework::DataTypeTrait<uint8_t>::DataType)
+                      : paddle::framework::ToMKLDNNDataType(
+                            framework::DataTypeTrait<int8_t>::DataType);
     if (force_fp32_output) {
       dst_dt = paddle::framework::ToMKLDNNDataType(
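Condensed, the selection above plus the force_fp32_output override that follows amounts to the hypothetical helper below; ChooseDstType does not exist in the patch, and we assume the fp32 branch uses DataTypeTrait<float> as the truncated hunk suggests:

// Restatement of the dst_dt logic, relying on this file's existing includes.
static mkldnn::memory::data_type ChooseDstType(bool unsigned_output,
                                               bool force_fp32_output) {
  if (force_fp32_output) {  // forced fp32 overrides the INT8 choice
    return paddle::framework::ToMKLDNNDataType(
        framework::DataTypeTrait<float>::DataType);
  }
  return unsigned_output ? paddle::framework::ToMKLDNNDataType(
                               framework::DataTypeTrait<uint8_t>::DataType)
                         : paddle::framework::ToMKLDNNDataType(
                               framework::DataTypeTrait<int8_t>::DataType);
}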
@@ -377,13 +379,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     key.reserve(MaxKeyLength);
     platform::ConvMKLDNNHandler::AppendKey(
         &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
-        input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
+        input->format(), fuse_relu, fuse_residual_conn, fuse_brelu,
         ctx.op().Input("Input") + ctx.op().Input("Filter"));
     const std::string key_conv_pd = key + "@conv_pd";
     bool need_s8_to_u8 = false;
     std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
     std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
     std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
@@ -456,6 +457,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
     // create a conv primitive descriptor and save it for usage in backward
+    // TODO(lidanqing): We use relu post-op instead of brelu post-op cause
+    // mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
+    // v0.20 is enabled
     if (bias) {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
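For context on the TODO: once a library version with an INT8 bounded-relu post-op is available, the fusion could be expressed directly instead of degrading brelu to plain relu. A sketch under that assumption, using the mkldnn v0.x post-op API with the variable names from this file (not part of the patch):

// Sketch only: post-ops as they could look once INT8 bounded relu is usable.
mkldnn::post_ops post_operations;
if (fuse_residual_conn) {
  post_operations.append_sum(sum_scale);  // accumulate into residual data
}
if (fuse_brelu) {
  // Bounded relu clamps the convolution output to [0, fuse_brelu_threshold].
  post_operations.append_eltwise(1.0f, mkldnn::algorithm::eltwise_bounded_relu,
                                 fuse_brelu_threshold, 0.0f);
} else if (fuse_relu) {
  post_operations.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu, 0.0f,
                                 0.0f);
}
mkldnn::primitive_attr conv_attr;
conv_attr.set_post_ops(post_operations);

This is also why the ConvFwdPrimitiveDesc calls in the next hunk pass fuse_relu || fuse_brelu as the relu flag and false for brelu, while already threading fuse_brelu_threshold through for the eventual switch.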
@@ -463,16 +467,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       conv_pd = ConvFwdPrimitiveDesc(
           src_md, weights_md, bias_md, dst_md, strides, paddings,
-          mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
-          0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale,
-          is_test);
+          mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
+          fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
+          output_shift_scale, sum_scale, is_test);
     } else {
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
-                                     paddings, mkldnn_engine, fuse_relu,
-                                     fuse_residual_conn, false /*fuse_brelu*/,
-                                     0.0 /*fuse_brelu_threshold*/,
-                                     output_shift_scale, sum_scale, is_test);
+      conv_pd = ConvFwdPrimitiveDesc(
+          src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
+          fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn,
+          false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale,
+          sum_scale, is_test);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -514,7 +518,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                 ctx, output, residual_param, user_residual_md, handler,
                 &pipeline);
           } else {
-            need_s8_to_u8 = fuse_relu;
+            need_s8_to_u8 = unsigned_output;
             dst_memory_p = platform::SetDstMemory<int8_t>(
                 ctx, output, residual_param, user_residual_md, handler,
                 &pipeline);
@@ -525,12 +529,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             dst_memory_p =
                 platform::SetDstMemory<uint8_t>(ctx, output, handler);
           } else {
-            need_s8_to_u8 = fuse_relu;
+            need_s8_to_u8 = unsigned_output;
             dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
           }
         }
       } else if (!force_fp32_output) {
-        if (fuse_relu) {
+        if (unsigned_output) {
           dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
         } else {
           dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
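A note on need_s8_to_u8, which these branches now derive from unsigned_output: when the destination memory was allocated as s8 but a fused relu/brelu guarantees non-negative results, the framework later reinterprets the buffer as u8. That round-trip is lossless because every value stays in [0, 127]; a standalone illustration (not from the patch):

#include <cstdint>
#include <cstdio>

int main() {
  // A non-negative s8 value keeps its bit pattern when read back as u8,
  // which is what makes the deferred s8 -> u8 reinterpretation safe.
  int8_t s8_result = 85;
  uint8_t u8_view = static_cast<uint8_t>(s8_result);
  std::printf("s8=%d u8=%u\n", s8_result, u8_view);  // prints: s8=85 u8=85
  return 0;
}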
@@ -602,12 +606,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
                                                  &dst_memory_p);
         } else {
-          need_s8_to_u8 = fuse_relu;
+          need_s8_to_u8 = unsigned_output;
           platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
                                                 &dst_memory_p);
         }
       } else if (!force_fp32_output) {
-        if (fuse_relu) {
+        if (unsigned_output) {
           platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
                                                  &dst_memory_p);
         } else {