add unfold op (new op), test=develop (#17944)

* add unfold op test=develop
* fix divide bug in python3 when calculating output width and height test=develop
* add name=None in python api, move redundant code into inline function
* try to trigger ci for this code test=develop
@@ -0,0 +1,184 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#include "paddle/fluid/operators/unfold_op.h"

namespace paddle {
namespace operators {

class UnfoldOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Tensor, "
             "the input of unfold op. "
             "The format of X is [N, C_in, H, W], "
             "where N is the batch size, C_in is the input channels, "
             "H is the height and W is the width");
    AddOutput(
        "Y",
        "Tensor, "
        "the output of unfold op. "
        "The format of Y is [N, C_in*filter_height*filter_width, "
        "output_height*output_width], where N is the batch size, "
        "C_in is the input channels of X, filter_height and filter_width are "
        "the height and width of the filtering kernel, and output_height and "
        "output_width are the calculated height and width of the output "
        "feature map.");
    AddAttr<std::vector<int>>(
        "kernel_sizes",
        "vector<int>, the kernel sizes of the convolution operator.");
    AddAttr<std::vector<int>>(
        "strides", "vector<int>, the strides of the convolution operator.");
    AddAttr<std::vector<int>>(
        "paddings",
        "vector<int>, the paddings applied to pad the feature map.");
    AddAttr<std::vector<int>>(
        "dilations", "vector<int>, the dilations of the convolution operator.");
    AddComment(R"DOC(
**Unfold Operator**

This operator extracts sliding local blocks from a batched input tensor, an operation
also known as im2col when applied to a batched 2D image tensor. For each block under the
convolution filter, all elements are rearranged into a column. As the convolution filter
slides over the input feature map, a series of such columns is formed.
    )DOC");
  }
};

class UnfoldOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of UnfoldOp should not be null");
    PADDLE_ENFORCE(ctx->HasOutput("Y"),
                   "Output(Y) of UnfoldOp should not be null");
    auto in_dims = ctx->GetInputDim("X");
    std::vector<int> kernel_sizes =
        ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
    std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
    std::vector<int> dilations =
        ctx->Attrs().Get<std::vector<int>>("dilations");

    // Only [N, C, H, W] input supported now
    PADDLE_ENFORCE(
        in_dims.size() == 4,
        "Input should be a 4-D tensor of format [N, C, H, W], but got %u",
        in_dims.size());
    PADDLE_ENFORCE(
        in_dims.size() - kernel_sizes.size() == 2U,
        "The dims of X should be larger than that of kernel_sizes "
        "by a number of 2, due to the batch size and input channel dim. "
        "But received dims(X:%u) - dims(kernel_sizes:%u) != 2",
        in_dims.size(), kernel_sizes.size());
    PADDLE_ENFORCE_EQ(
        strides.size(), kernel_sizes.size(),
        "The dims of strides should be the same as that of kernel_sizes. "
        "But received dims(strides: %u) != dims(kernel_sizes: %u).",
        strides.size(), kernel_sizes.size());
    PADDLE_ENFORCE_EQ(
        paddings.size(), 2 * strides.size(),
        "The dims of paddings should be 2 times that of strides. "
        "But received dims(paddings: %u) != 2*dims(strides: %u).",
        paddings.size(), strides.size());
    PADDLE_ENFORCE_EQ(
        strides.size(), dilations.size(),
        "The dims of strides should be the same as that of dilations. "
        "But received dims(strides: %u) != dims(dilations: %u).",
        strides.size(), dilations.size());

    std::vector<int> out_dims;
    out_dims.push_back(in_dims[0]);

    int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
    out_dims.push_back(output_channels);

    int output_height =
        CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
                       paddings[2], strides[0]);
    int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
                                      paddings[1], paddings[3], strides[1]);
    int output_col_length = output_height * output_width;
    out_dims.push_back(output_col_length);

    ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
                                   ctx.device_context());
  }
};

class UnfoldGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "The gradient of Y should not be null");
    PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "The gradient of X should not be null");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<framework::Tensor>(framework::GradVarName("Y"))->type(),
        ctx.device_context());
  }
};

class UnfoldGradDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("unfold_grad");
    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
    op->SetInput("X", Input("X"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

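// The backward kernel only reads the shape of X (to size the X gradient), not
// its data, so X's buffer is declared as not needed for the backward pass.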
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(UnfoldGradOpNoNeedBufferVarsInference,
                                      "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker,
                  ops::UnfoldGradDescMaker);
REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp,
                  ops::UnfoldGradOpNoNeedBufferVarsInference);

REGISTER_OP_CPU_KERNEL(
    unfold, ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    unfold_grad,
    ops::UnfoldGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnfoldGradOpKernel<paddle::platform::CPUDeviceContext, double>);
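For intuition, the shape logic in UnfoldOp::InferShape above can be written out as a small standalone Python sketch. It is not part of this patch and the helper name is hypothetical; paddings follow the same order as the CalcOutputSize calls above, i.e. indices 0 and 2 pad the height and indices 1 and 3 pad the width.

# Hypothetical helper mirroring UnfoldOp::InferShape; not part of the patch.
def unfold_output_shape(x_shape, kernel_sizes, strides, paddings, dilations):
    """Return [N, C * kh * kw, out_h * out_w] for an [N, C, H, W] input."""
    n, c, h, w = x_shape
    kh, kw = kernel_sizes
    dkernel_h = dilations[0] * (kh - 1) + 1
    dkernel_w = dilations[1] * (kw - 1) + 1
    out_h = (h + paddings[0] + paddings[2] - dkernel_h) // strides[0] + 1
    out_w = (w + paddings[1] + paddings[3] - dkernel_w) // strides[1] + 1
    return [n, c * kh * kw, out_h * out_w]

# For the configuration used in the unit test below (3x3 kernel, stride 1,
# one pixel of padding on each side, no dilation):
# unfold_output_shape([3, 3, 20, 20], [3, 3], [1, 1], [1, 1, 1, 1], [1, 1])
# -> [3, 27, 400]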
@@ -0,0 +1,26 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/unfold_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    unfold, ops::UnfoldOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::UnfoldOpKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    unfold_grad,
    ops::UnfoldGradOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::UnfoldGradOpKernel<paddle::platform::CUDADeviceContext, double>);
@@ -0,0 +1,127 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#pragma once

#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

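// Output size along one spatial dimension, with the kernel dilated to
// dilation * (filter_size - 1) + 1:
//   output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1
// e.g. input_size = 20, filter_size = 3, dilation = 1, padding1 = padding2 = 1,
// stride = 1  ->  dkernel = 3, output_size = (20 + 1 + 1 - 3) / 1 + 1 = 20.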
inline int CalcOutputSize(int input_size, int filter_size, int dilation,
                          int padding1, int padding2, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
  PADDLE_ENFORCE(output_size > 0,
                 "Due to the settings of padding(%d, %d), filter_size(%d), "
                 "dilation(%d) and stride(%d), the calculated output size is "
                 "not positive, please check again. Input_size: %d",
                 padding1, padding2, filter_size, dilation, stride, input_size);

  return output_size;
}

template <typename DeviceContext, typename T>
class UnfoldOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* input = ctx.Input<Tensor>("X");
    const int batch_size = static_cast<int>(input->dims()[0]);
    Tensor* output = ctx.Output<Tensor>("Y");
    output->mutable_data<T>(ctx.GetPlace());

    std::vector<int> kernel_sizes = ctx.Attr<std::vector<int>>("kernel_sizes");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");

    math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto input_dims = input->dims();

    int output_height =
        CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0],
                       paddings[0], paddings[2], strides[0]);
    int output_width =
        CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1],
                       paddings[1], paddings[3], strides[1]);

    framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
    framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0],
                                         kernel_sizes[1], output_height,
                                         output_width});

    for (int i = 0; i < batch_size; i++) {
      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
      im2col(dev_ctx, in_batch, dilations, strides, paddings, &out_batch);
    }
  }
};

template <typename DeviceContext, typename T>
class UnfoldGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* output_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    if ((!output_grad) || (!input_grad)) return;
    input_grad->mutable_data<T>(ctx.GetPlace());

    std::vector<int> kernel_sizes = ctx.Attr<std::vector<int>>("kernel_sizes");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");

    const int batch_size = static_cast<int>(input_grad->dims()[0]);

    auto input_dims = input_grad->dims();

    int output_height =
        CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0],
                       paddings[0], paddings[2], strides[0]);
    int output_width =
        CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1],
                       paddings[1], paddings[3], strides[1]);

    framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
    framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0],
                                         kernel_sizes[1], output_height,
                                         output_width});

    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

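    // col2im accumulates each column back into (possibly overlapping) image
    // locations, so the gradient buffer must be zero-initialized first.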
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, input_grad, static_cast<T>(0));
    for (int i = 0; i < batch_size; i++) {
      Tensor out_grad_batch =
          output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
      Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
      col2im(dev_ctx, out_grad_batch, dilations, strides, paddings,
             &in_grad_batch);
    }
  }
};
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,102 @@
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import math
import numpy as np
import unittest
from op_test import OpTest


class TestUnfoldOp(OpTest):
    """Tests for the unfold op."""

    def init_data(self):
        self.batch_size = 3
        self.input_channels = 3
        self.input_height = 20
        self.input_width = 20
        self.kernel_sizes = [3, 3]
        self.strides = [1, 1]
        self.paddings = [1, 1, 1, 1]
        self.dilations = [1, 1]
        input_shape = [
            self.batch_size, self.input_channels, self.input_height,
            self.input_width
        ]
        self.x = np.random.rand(*input_shape).astype(np.float32)

    def calc_unfold(self):
        output_shape = [0] * 3
        output_shape[0] = self.batch_size
        output_shape[1] = self.input_channels * self.kernel_sizes[
            0] * self.kernel_sizes[1]
        dkernel_h = self.dilations[0] * (self.kernel_sizes[0] - 1) + 1
        dkernel_w = self.dilations[1] * (self.kernel_sizes[1] - 1) + 1
        out_height = int((self.input_height + self.paddings[0] +
                          self.paddings[2] - dkernel_h) / self.strides[0]) + 1
        out_width = int((self.input_width + self.paddings[1] + self.paddings[3]
                         - dkernel_w) / self.strides[1]) + 1
        output_shape[2] = out_height * out_width
        output = np.zeros(output_shape).astype(np.float32)
        ############ calculate output ##############
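        # Index mapping (mirrors the im2col layout used by the C++ kernel):
        #   j -> (c_in, h_offset, w_offset): input channel and offset inside the kernel
        #   k -> (h_out, w_out): position in the output feature map
        #   (h_in, w_in): corresponding position in the (padded) input image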
        for i in range(output_shape[0]):
            for j in range(output_shape[1]):
                for k in range(output_shape[2]):
                    h_out = int(k / out_width)
                    w_out = k % out_width
                    w_offset = j % self.kernel_sizes[1]
                    h_offset = int(j /
                                   self.kernel_sizes[1]) % self.kernel_sizes[0]
                    c_in = int(j /
                               (self.kernel_sizes[0] * self.kernel_sizes[1]))
                    h_in = h_offset * self.dilations[0] + h_out * self.strides[
                        0] - self.paddings[0]
                    w_in = w_offset * self.dilations[1] + w_out * self.strides[
                        1] - self.paddings[1]
                    if (h_in >= 0 and h_in < self.input_height) and \
                            (w_in >= 0 and w_in < self.input_width):
                        output[i, j, k] = self.x[i, c_in, h_in, w_in]

        self.outputs = output

    def set_data(self):
        self.init_data()
        self.calc_unfold()

        self.inputs = {'X': self.x}
        self.attrs = {
            'kernel_sizes': self.kernel_sizes,
            'paddings': self.paddings,
            'dilations': self.dilations,
            'strides': self.strides
        }
        self.outputs = {'Y': self.outputs}

    def setUp(self):
        self.op_type = 'unfold'
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y')


if __name__ == '__main__':
    unittest.main()