|  |  | @ -21,6 +21,7 @@ limitations under the License. */ | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "boost/optional.hpp" |  |  |  | #include "boost/optional.hpp" | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/fluid/framework/data_layout_transform.h" |  |  |  | #include "paddle/fluid/framework/data_layout_transform.h" | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/fluid/framework/operator.h" |  |  |  | #include "paddle/fluid/framework/operator.h" | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | #include "paddle/fluid/operators/pool_op.h" | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/fluid/platform/mkldnn_helper.h" |  |  |  | #include "paddle/fluid/platform/mkldnn_helper.h" | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/fluid/platform/place.h" |  |  |  | #include "paddle/fluid/platform/place.h" | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
	
		
		
			
				
					|  |  | @ -592,41 +593,100 @@ template <typename T> | 
			
		
	
		
		
			
				
					
					|  |  |  | class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, |  |  |  | class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                    mkldnn::pooling_backward> { |  |  |  |                                                    mkldnn::pooling_backward> { | 
			
		
	
		
		
			
				
					
					|  |  |  |  public: |  |  |  |  public: | 
			
		
	
		
		
			
				
					
					|  |  |  |   PoolingMKLDNNHandler( |  |  |  |   PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |       const std::vector<int64_t>& src_dims, |  |  |  |                        const MKLDNNDeviceContext& dev_ctx, | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |       const std::vector<int64_t>& dst_dims, const std::vector<int64_t>& ksize, |  |  |  |                        const mkldnn::engine mkldnn_engine, | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |       const std::vector<int64_t>& strides, const std::vector<int64_t>& paddings, |  |  |  |                        platform::Place cpu_place, const Tensor* input, | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |       const std::string& pooling_type, bool ceil_mode, |  |  |  |                        Tensor* output, const std::string& unique_name) | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |       const MKLDNNMemoryFormat fmt, mkldnn::memory::data_type dt, bool is_test, |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |       const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |       const std::string& unique_name, bool exclude_padding) |  |  |  |  | 
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |       : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward, |  |  |  |       : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                  mkldnn::pooling_backward>( |  |  |  |                                  mkldnn::pooling_backward>( | 
			
		
	
		
		
			
				
					
					|  |  |  |             dev_ctx, dev_ctx.GetEngine(), cpu_place, |  |  |  |             dev_ctx, dev_ctx.GetEngine(), cpu_place, | 
			
		
	
		
		
			
				
					
					|  |  |  |             platform::CreateKey(src_dims, dt, unique_name)) { |  |  |  |             platform::CreateKey(framework::vectorize(input->dims()), | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     auto src_md = mkldnn::memory::desc(src_dims, dt, fmt); |  |  |  |                                 framework::ToMKLDNNDataType(input->type()), | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     /* create memory descriptor for pooling without specified format
 |  |  |  |                                 unique_name)) { | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |      * ('any') which lets a primitive (pooling in this case) choose |  |  |  |     if (!this->isCached()) { | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |      * the memory format preferred for best performance |  |  |  |       PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |      */ |  |  |  |                         platform::errors::InvalidArgument( | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     auto dst_md = |  |  |  |                             "Wrong layout set for Input tensor")); | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |         platform::MKLDNNMemDesc(dst_dims, dt, MKLDNNMemoryFormat::any); |  |  |  |       PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef, | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         platform::errors::InvalidArgument( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                             "Wrong format set for Input tensor")); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const std::string pooling_type = ctx.Attr<std::string>("pooling_type"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp)); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       std::vector<int64_t> strides(begin(strides_temp), end(strides_temp)); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp)); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const bool global_pooling = ctx.Attr<bool>("global_pooling"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const std::string padding_algorithm = | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           ctx.Attr<std::string>("padding_algorithm"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       // Only 2D pooling is supported now
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       PADDLE_ENFORCE_EQ(ksize.size(), 2, | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         platform::errors::InvalidArgument( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                             "ksize must be 2D, i.e. 2D pooling")); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true, | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         platform::errors::InvalidArgument( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                             "pooling_type must be 'max' or 'avg'")); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       PADDLE_ENFORCE_EQ(input->dims().size(), 4, | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         platform::errors::InvalidArgument( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                             "Input dim must be with 4, i.e. NCHW")); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto input_dims = input->dims(); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       framework::DDim data_dims = | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           framework::slice_ddim(input_dims, 2, input_dims.size()); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       if (global_pooling) { | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         operators::UpdateKsize(&ksize, data_dims); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       } | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     auto mkldnn_paddings = ToMkldnnPadding(paddings); |  |  |  |       operators::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                data_dims, strides, ksize); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto src_tz = paddle::framework::vectorize(input->dims()); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto dst_tz = paddle::framework::vectorize(output->dims()); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto is_test = ctx.Attr<bool>("is_test"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto dt = framework::ToMKLDNNDataType(input->type()); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto fmt = input->format(); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto exclude_padding = ctx.Attr<bool>("exclusive"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto src_md = mkldnn::memory::desc(src_tz, dt, fmt); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       /* create memory descriptor for pooling without specified format
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |        * ('any') which lets a primitive (pooling in this case) choose | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |        * the memory format preferred for best performance | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |        */ | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       const auto dst_md = | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any); | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     if (ceil_mode) { |  |  |  |       auto mkldnn_paddings = ToMkldnnPadding(paddings); | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |       CorrectOutputSize(src_dims, dst_dims, ksize, paddings, strides, |  |  |  | 
 | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                         mkldnn_paddings[1]); |  |  |  |       const bool ceil_mode = ctx.Attr<bool>("ceil_mode"); | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       if (ceil_mode) { | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                           mkldnn_paddings[1]); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       } | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       this->AcquireForwardPrimitiveDescriptor( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           is_test ? mkldnn::prop_kind::forward_inference | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                   : mkldnn::prop_kind::forward_training, | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           pooling_type == "max" | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |               ? mkldnn::algorithm::pooling_max | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |               : (exclude_padding | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                      ? mkldnn::algorithm::pooling_avg_exclude_padding | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                      : mkldnn::algorithm::pooling_avg_include_padding), | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           src_md, dst_md, strides, ksize, mkldnn_paddings[0], | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |           mkldnn_paddings[1]); | 
			
		
	
		
		
			
				
					
					|  |  |  |     } |  |  |  |     } | 
			
		
	
		
		
			
				
					
					|  |  |  |     this->AcquireForwardPrimitiveDescriptor( |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         is_test ? mkldnn::prop_kind::forward_inference |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |                 : mkldnn::prop_kind::forward_training, |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         pooling_type == "max" |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |             ? mkldnn::algorithm::pooling_max |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |             : (exclude_padding |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |                    ? mkldnn::algorithm::pooling_avg_exclude_padding |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |                    : mkldnn::algorithm::pooling_avg_include_padding), |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]); |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |   } |  |  |  |   } | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |   PoolingMKLDNNHandler( |  |  |  |   PoolingMKLDNNHandler( | 
			
		
	
	
		
		
			
				
					|  |  | 
 |