@ -383,14 +383,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std : : string key_conv_pd = key + " @conv_pd " ;
bool need_s8_to_u8 = false ;
std : : shared_ptr < mkldnn : : convolution_forward > conv_p = nullptr ;
std : : shared_ptr < mkldnn : : memory > src_memory_p = nullptr ;
std : : shared_ptr < mkldnn : : memory > user_src_memory_p = nullptr ;
std : : shared_ptr < mkldnn : : memory > dst_memory_p = nullptr ;
std : : shared_ptr < mkldnn : : convolution_forward > conv_p ;
std : : shared_ptr < mkldnn : : memory > src_memory_p ;
std : : shared_ptr < mkldnn : : memory > user_src_memory_p ;
std : : shared_ptr < mkldnn : : memory > dst_memory_p ;
std : : vector < primitive > pipeline ;
std : : shared_ptr < mkldnn : : convolution_forward : : primitive_desc > conv_pd =
nullptr ;
std : : shared_ptr < platform : : ConvMKLDNNHandler > handler = nullptr ;
std : : shared_ptr < mkldnn : : convolution_forward : : primitive_desc > conv_pd ;
std : : shared_ptr < platform : : ConvMKLDNNHandler > handler ;
auto prim_key = key + " @conv_p " ;
auto dst_key = key + " @dst_mem_p " ;
@ -460,24 +459,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// TODO(lidanqing): We use relu post-op instead of brelu post-op because
// mkldnn v0.18 does not support INT8 brelu post-op. Use the code in /**/
// when v0.20 is enabled.
std : : shared_ptr < memory : : desc > bias_md_p ;
if ( bias ) {
bias_tz = paddle : : framework : : vectorize2int ( bias - > dims ( ) ) ;
auto bias_md = platform : : MKLDNNMemDesc ( bias_tz , memory : : data_type : : s32 ,
memory : : format : : x ) ;
conv_pd = ConvFwdPrimitiveDesc (
src_md , weights_md , bias_md , dst_md , strides , paddings ,
mkldnn_engine , fuse_relu | | fuse_brelu /*fuse_relu*/ ,
fuse_residual_conn , false /*fuse_brelu*/ , fuse_brelu_threshold ,
output_shift_scale , sum_scale , is_test ) ;
} else {
conv_pd = ConvFwdPrimitiveDesc (
src_md , weights_md , dst_md , strides , paddings , mkldnn_engine ,
fuse_relu | | fuse_brelu /*fuse_relu*/ , fuse_residual_conn ,
false /*fuse_brelu*/ , fuse_brelu_threshold , output_shift_scale ,
sum_scale , is_test ) ;
bias_md_p = std : : make_shared < memory : : desc > ( platform : : MKLDNNMemDesc (
bias_tz , memory : : data_type : : s32 , memory : : format : : x ) ) ;
}
conv_pd = ConvFwdPrimitiveDesc (
src_md , weights_md , bias_md_p , dst_md , strides , paddings ,
mkldnn_engine , fuse_relu | | fuse_brelu /*fuse_relu*/ ,
fuse_residual_conn , false /*fuse_brelu*/ , fuse_brelu_threshold ,
output_shift_scale , sum_scale , is_test ) ;
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx . SetBlob ( key_conv_pd , conv_pd ) ;
handler . reset ( new platform : : ConvMKLDNNHandler ( conv_pd , dev_ctx ,
@ -649,7 +641,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private :
mkldnn : : primitive_attr CreatePostOps (
bool fuse_relu , bool fuse_residual_conn ,
const std : : vector < float > output_shift_scale , float sum_scale ,
const std : : vector < float > & output_shift_scale , float sum_scale ,
bool fuse_brelu , float fuse_brelu_threshold ) const {
mkldnn : : primitive_attr conv_attr ;
mkldnn : : post_ops post_operations ;
@ -679,52 +671,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std : : unique_ptr < mkldnn : : convolution_forward : : primitive_desc >
ConvFwdPrimitiveDesc ( const memory : : desc & src , const memory : : desc & weights ,
const std : : shared_ptr < memory : : desc > bias_md_p ,
const memory : : desc & dst , const std : : vector < int > & strides ,
const std : : vector < int > & paddings ,
const mkldnn : : engine & engine , const bool fuse_relu ,
const bool fuse_residual_conn , const bool fuse_brelu ,
const float fuse_brelu_threshold ,
const std : : vector < float > output_shift_scale ,
const std : : vector < float > & output_shift_scale ,
const float sum_scale , bool is_test ) const {
memory : : dims stride_dims = { strides [ 0 ] , strides [ 1 ] } ;
memory : : dims padding_dims = { paddings [ 0 ] , paddings [ 1 ] } ;
auto propagation = is_test ? mkldnn : : prop_kind : : forward_scoring
: mkldnn : : prop_kind : : forward_training ;
auto conv_desc = mkldnn : : convolution_forward : : desc (
propagation , mkldnn : : convolution_direct , src , weights , dst , stride_dims ,
padding_dims , padding_dims , mkldnn : : padding_kind : : zero ) ;
mkldnn : : primitive_attr conv_attr =
CreatePostOps ( fuse_relu , fuse_residual_conn , output_shift_scale ,
sum_scale , fuse_brelu , fuse_brelu_threshold ) ;
auto p_conv_pd = new mkldnn : : convolution_forward : : primitive_desc (
conv_desc , conv_attr , engine ) ;
return std : : unique_ptr < mkldnn : : convolution_forward : : primitive_desc > (
p_conv_pd ) ;
}
std : : unique_ptr < mkldnn : : convolution_forward : : primitive_desc >
ConvFwdPrimitiveDesc ( const memory : : desc & src , const memory : : desc & weights ,
const memory : : desc & bias , const memory : : desc & dst ,
const std : : vector < int > & strides ,
const std : : vector < int > & paddings ,
const mkldnn : : engine & engine , const bool fuse_relu ,
const bool fuse_residual_conn , const bool fuse_brelu ,
const float fuse_brelu_threshold ,
const std : : vector < float > output_shift_scale ,
const float sum_scale , bool is_test ) const {
memory : : dims stride_dims = { strides [ 0 ] , strides [ 1 ] } ;
memory : : dims padding_dims = { paddings [ 0 ] , paddings [ 1 ] } ;
auto propagation = is_test ? mkldnn : : prop_kind : : forward_scoring
: mkldnn : : prop_kind : : forward_training ;
auto conv_desc = mkldnn : : convolution_forward : : desc (
propagation , mkldnn : : convolution_direct , src , weights , bias , dst ,
stride_dims , padding_dims , padding_dims , mkldnn : : padding_kind : : zero ) ;
auto conv_desc =
( bias_md_p ! = nullptr )
? mkldnn : : convolution_forward : : desc (
propagation , mkldnn : : convolution_direct , src , weights ,
( * bias_md_p ) , dst , stride_dims , padding_dims , padding_dims ,
mkldnn : : padding_kind : : zero )
: mkldnn : : convolution_forward : : desc (
propagation , mkldnn : : convolution_direct , src , weights , dst ,
stride_dims , padding_dims , padding_dims ,
mkldnn : : padding_kind : : zero ) ;
mkldnn : : primitive_attr conv_attr =
CreatePostOps ( fuse_relu , fuse_residual_conn , output_shift_scale ,