@ -529,7 +529,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
bool ceil_mode , const MKLDNNMemoryFormat fmt ,
mkldnn : : memory : : data_type dt , bool is_test ,
const platform : : MKLDNNDeviceContext & dev_ctx , platform : : Place cpu_place ,
const std : : string & unique_name )
const std : : string & unique_name , bool exclude_padding )
: platform : : MKLDNNHandlerT < T , mkldnn : : pooling_forward ,
mkldnn : : pooling_backward > (
dev_ctx , dev_ctx . GetEngine ( ) , cpu_place ,
@ -553,8 +553,11 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
this - > AcquireForwardPrimitiveDescriptor (
is_test ? mkldnn : : prop_kind : : forward_inference
: mkldnn : : prop_kind : : forward_training ,
pooling_type = = " max " ? mkldnn : : algorithm : : pooling_max
: mkldnn : : algorithm : : pooling_avg ,
pooling_type = = " max "
? mkldnn : : algorithm : : pooling_max
: ( exclude_padding
? mkldnn : : algorithm : : pooling_avg_exclude_padding
: mkldnn : : algorithm : : pooling_avg_include_padding ) ,
src_md , dst_md , strides , ksize , padding_left_top , padding_right_bottom ,
mkldnn : : padding_kind : : zero ) ;
}
@ -567,7 +570,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
const MKLDNNMemoryFormat fmt , const MKLDNNMemoryFormat diff_dst_fmt ,
mkldnn : : memory : : data_type dt ,
const platform : : MKLDNNDeviceContext & dev_ctx , platform : : Place cpu_place ,
const std : : string & unique_name )
const std : : string & unique_name , bool exclude_padding )
: platform : : MKLDNNHandlerT < T , mkldnn : : pooling_forward ,
mkldnn : : pooling_backward > (
dev_ctx , dev_ctx . GetEngine ( ) , cpu_place ,
@ -580,8 +583,11 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
MKLDNNMemoryFormat : : any ) ;
this - > AcquireBackwardPrimitiveDescriptor (
pooling_type = = " max " ? mkldnn : : algorithm : : pooling_max
: mkldnn : : algorithm : : pooling_avg ,
pooling_type = = " max "
? mkldnn : : algorithm : : pooling_max
: ( exclude_padding
? mkldnn : : algorithm : : pooling_avg_exclude_padding
: mkldnn : : algorithm : : pooling_avg_include_padding ) ,
diff_src_md , diff_dst_md , strides , ksize , paddings , paddings ,
mkldnn : : padding_kind : : zero ) ;
}