From 6fc9a9fd690e2d5fe48f2b39ed2575a04ef32103 Mon Sep 17 00:00:00 2001
From: sweetsky0901 <work@yq01-idl-gpu-online20.yq01.baidu.com>
Date: Tue, 28 Nov 2017 23:15:09 +0800
Subject: [PATCH] remove the T2 template parameter and update the unpool doc

---
 paddle/operators/math/unpooling.cc            | 20 +++++-----
 paddle/operators/math/unpooling.cu            | 39 +++++++++----------
 paddle/operators/math/unpooling.h             |  4 +-
 paddle/operators/unpool_op.cc                 | 19 +++++----
 paddle/operators/unpool_op.cu.cc              |  8 ++--
 paddle/operators/unpool_op.h                  |  8 ++--
 .../paddle/v2/fluid/tests/test_unpool_op.py   |  4 +-
 7 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc
index ab6212f387..dbc3936971 100644
--- a/paddle/operators/math/unpooling.cc
+++ b/paddle/operators/math/unpooling.cc
@@ -19,8 +19,8 @@ namespace operators {
 namespace math {
 // All tensors are in NCHW format
-template <typename T, typename T2>
-class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -35,7 +35,7 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
     int input_feasize = input_height * input_width;
     int output_feasize = output_height * output_width;
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int * indices_data = indices.data<int>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
     for (int b = 0; b < batch_size; ++b) {
       for (int c = 0; c < output_channels; ++c) {
@@ -54,8 +54,8 @@

-template <class T, typename T2>
-class Unpool2dMaxGradFunctor<platform::CPUPlace, T, T2> {
+template <class T>
+class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -71,7 +71,7 @@ public:
   const int output_width = output.dims()[3];
   int input_feasize = input_height * input_width;
   int output_feasize = output_height * output_width;
-  const T2 * indices_data = indices.data<T2>();
+  const int * indices_data = indices.data<int>();
   const T* output_grad_data = output_grad.data<T>();
   T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
@@ -90,10 +90,10 @@ public:
   }
 };
-template class Unpool2dMaxGradFunctor<platform::CPUPlace, float, int>;
-template class Unpool2dMaxGradFunctor<platform::CPUPlace, double, int>;
-template class Unpool2dMaxFunctor<platform::CPUPlace, float, int>;
-template class Unpool2dMaxFunctor<platform::CPUPlace, double, int>;
+template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
+template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
+template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
+template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
 }  // namespace math
 }  // namespace operators
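Note: the CPU functor above reduces, per (batch, channel) slice, to a simple scatter;
each retained maximum is written to the flat output position recorded in Indices. A
minimal standalone sketch of that loop (illustrative names, not the Paddle API):

    #include <cassert>
    #include <vector>

    // One (batch, channel) slice: `input` and `indices` hold input_h * input_w
    // values; `output` holds output_h * output_w values and is pre-zeroed.
    void unpool2d_max_slice(const std::vector<float>& input,
                            const std::vector<int>& indices,
                            std::vector<float>* output) {
      for (size_t i = 0; i < input.size(); ++i) {
        int index = indices[i];  // flat offset into the output slice
        assert(index >= 0 && index < static_cast<int>(output->size()));
        (*output)[index] = input[i];  // scatter the pooled max back out
      }
    }

Everything the scatter does not touch stays zero, which is why the op kernel
zero-fills the output before invoking the functor (see unpool_op.h further down).
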
diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu
index 99e6fd052a..9cdd61f6d5 100644
--- a/paddle/operators/math/unpooling.cu
+++ b/paddle/operators/math/unpooling.cu
@@ -19,10 +19,10 @@ namespace paddle {
 namespace operators {
 namespace math {
-template <typename T, typename T2>
+template <typename T>
 __global__ void KernelUnpool2dMax(const int nthreads,
                                   const T* input_data,
-                                  const T2 * indices_data,
+                                  const int * indices_data,
                                   const int input_height,
                                   const int input_width,
                                   const int channels,
@@ -45,10 +45,10 @@ __global__ void KernelUnpool2dMax(const int nthreads,
     output_data[out_offset + out_index] = input_data[i];
   }
 }
-template <typename T, typename T2>
+template <typename T>
 __global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                       const T* input_data,
-                                      const T2* indices_data,
+                                      const int* indices_data,
                                       const int input_height,
                                       const int input_width,
                                       const int channels,
@@ -76,8 +76,8 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
 /*
  * All tensors are in NCHW format.
  */
-template <typename T, typename T2>
-class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -90,15 +90,14 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
   const int output_height = output->dims()[2];
   const int output_width = output->dims()[3];
   const T* input_data = input.data<T>();
-  const T2 * indices_data = indices.data<T2>();
+  const int * indices_data = indices.data<int>();
   T* output_data = output->mutable_data<T>(context.GetPlace());
-  int nthreads = batch_size * output_channels * input_height * input_width;
   int threads = 1024;
   int grid = (input.numel() + threads - 1) / threads;
   KernelUnpool2dMax<
-      T, T2><<<grid, threads, 0,
+      T><<<grid, threads, 0,
            reinterpret_cast<const platform::CUDADeviceContext&>(context)
-               .stream()>>>(nthreads, input_data, indices_data,
+               .stream()>>>(input.numel(), input_data, indices_data,
                             input_height, input_width, output_channels,
                             output_data, output_height, output_width);
   }
@@ -106,8 +105,8 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
 /*
  * All tensors are in NCHW format.
  */
-template <typename T, typename T2>
-class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -122,18 +121,16 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
   const int output_height = output.dims()[2];
   const int output_width = output.dims()[3];
   const T* input_data = input.data<T>();
-  const T2 * indices_data = indices.data<T2>();
+  const int * indices_data = indices.data<int>();
   const T* output_data = output.data<T>();
   const T* output_grad_data = output_grad.data<T>();
   T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
-  int nthreads = batch_size * output_channels * input_height * input_width;
   int threads = 1024;
   int grid = (input.numel() + threads - 1) / threads;
   KernelUnpool2dMaxGrad<
-      T, T2><<<grid, threads, 0,
+      T><<<grid, threads, 0,
            reinterpret_cast<const platform::CUDADeviceContext&>(context)
-               .stream()>>>(
-                   nthreads, input_data, indices_data,
+               .stream()>>>(input.numel(), input_data, indices_data,
                             input_height, input_width, output_channels,
                             output_data, output_grad_data,
                             output_height, output_width,
@@ -141,11 +138,11 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
   }
 };
-template class Unpool2dMaxGradFunctor<platform::GPUPlace, float, int>;
-template class Unpool2dMaxGradFunctor<platform::GPUPlace, double, int>;
+template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
+template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;

-template class Unpool2dMaxFunctor<platform::GPUPlace, float, int>;
-template class Unpool2dMaxFunctor<platform::GPUPlace, double, int>;
+template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
+template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
 }  // namespace math
 }  // namespace operators
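Note: dropping the local nthreads is behavior-preserving here. For an NCHW input,
batch_size * output_channels * input_height * input_width is exactly input.numel()
(unpooling never changes the channel count), and the same element count also sizes
the grid. A host-side sketch of the launch arithmetic, assuming one thread per
input element and an illustrative `input_numel` variable:

    // Ceil-divide the element count by the block size; each thread handles one
    // input element, so the count passed as `nthreads` must match the grid.
    int threads = 1024;
    int grid = (input_numel + threads - 1) / threads;  // input_numel == N*C*H_in*W_in
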
diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h
index e086b891a1..bf79354ed9 100644
--- a/paddle/operators/math/unpooling.h
+++ b/paddle/operators/math/unpooling.h
@@ -19,7 +19,7 @@ namespace paddle {
 namespace operators {
 namespace math {
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class Unpool2dMaxFunctor {
  public:

@@ -29,7 +29,7 @@ class Unpool2dMaxFunctor {
                   framework::Tensor * output);
 };
-template <typename Place, class T, typename T2>
+template <typename Place, class T>
 class Unpool2dMaxGradFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc
index 49a5129188..2505148764 100644
--- a/paddle/operators/unpool_op.cc
+++ b/paddle/operators/unpool_op.cc
@@ -50,10 +50,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
         "(string), unpooling type, can be \"max\" for max-unpooling ")
         .InEnum({"max"});
     AddComment(R"DOC(
-        "Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
+        Input shape: $(N, C_{in}, H_{in}, W_{in})$
+        Output shape: $(N, C_{out}, H_{out}, W_{out})$
+        Where
+        $$
+        H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + ksize[0] \\
+        W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + ksize[1]
+        $$
+        Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
         /07/iccv2011.pdf
-        PyTorch: http://pytorch.org/docs/master/nn.html?highlight=unpool#
-        torch.nn.MaxUnpool2d"
         )DOC");
   }
 };
@@ -125,9 +130,9 @@
 namespace ops = paddle::operators;
 REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
             ops::UnpoolOpGrad);
 REGISTER_OP_CPU_KERNEL(unpool,
-    ops::UnpoolKernel<paddle::platform::CPUPlace, float, int>,
-    ops::UnpoolKernel<paddle::platform::CPUPlace, double, int>);
+    ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
+    ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(unpool_grad,
-    ops::UnpoolGradKernel<paddle::platform::CPUPlace, float, int>,
-    ops::UnpoolGradKernel<paddle::platform::CPUPlace, double, int>);
+    ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
+    ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc
index 9b5ac667d3..d8214fc687 100644
--- a/paddle/operators/unpool_op.cu.cc
+++ b/paddle/operators/unpool_op.cu.cc
@@ -16,10 +16,10 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(unpool,
-                       ops::UnpoolKernel<paddle::platform::GPUPlace, float, int>,
-                       ops::UnpoolKernel<paddle::platform::GPUPlace, double, int>);
+                       ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
+                       ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(unpool_grad,
                        ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                                             float, int>,
+                                             float>,
                        ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                                             double, int>);
+                                             double>);
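Note: the shape rule added to the DOC block is the inverse of the pooling shape
rule, so it is easy to sanity-check numerically. With H_in = 5, ksize = 3,
stride = 2, padding = 0 (stride and padding here are assumed illustrative values;
only shape and ksize are visible in the test hunk below), the output extent is
(5 - 1) * 2 - 2 * 0 + 3 = 11. As a one-line helper:

    // Output extent of max-unpool along one axis, per the DOC formula above.
    int unpool_out_size(int in_size, int ksize, int stride, int padding) {
      return (in_size - 1) * stride - 2 * padding + ksize;
    }
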
diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h
index dfd4ef12b5..f618a7c0ba 100644
--- a/paddle/operators/unpool_op.h
+++ b/paddle/operators/unpool_op.h
@@ -21,7 +21,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class UnpoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -37,12 +37,12 @@ class UnpoolKernel : public framework::OpKernel<T> {
       math::SetConstant<Place, T> set_zero;
       set_zero(context.device_context(), out, static_cast<T>(0));
     }
-    math::Unpool2dMaxFunctor<Place, T, T2> unpool2d_max_forward;
+    math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
     unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
   }
 };
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class UnpoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -64,7 +64,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
       in_x_grad->mutable_data<T>(context.GetPlace());
       zero(device_ctx, in_x_grad, static_cast<T>(0));
     }
-    math::Unpool2dMaxGradFunctor<Place, T, T2> unpool2d_max_backward;
+    math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
     unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out,
                           *out_grad, in_x_grad);
   }
diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py
index b3c6c85025..292b9bc14a 100644
--- a/python/paddle/v2/fluid/tests/test_unpool_op.py
+++ b/python/paddle/v2/fluid/tests/test_unpool_op.py
@@ -50,7 +50,7 @@ class TestUnpoolOp(OpTest):
                     indices[nidx, cidx, i, j] = \
                         (r_start + arg / self.ksize[1]) * wsize + \
                         c_start + arg % self.ksize[1]
-        output = self.Unpool2d_forward_naive(input, indices, self.ksize, \
+        output = self.unpool2d_forward_naive(input, indices, self.ksize, \
                 self.strides, self.paddings).astype("float32")
         self.inputs = {'X': input.astype('float32'),
                        'Indices': indices.astype('int32')}
@@ -69,7 +69,7 @@ class TestUnpoolOp(OpTest):
         self.check_grad(['X'], 'Out')

     def init_test_case(self):
-        self.Unpool2d_forward_naive = unpool2dmax_forward_naive
+        self.unpool2d_forward_naive = unpool2dmax_forward_naive
         self.unpooling_type = "max"
         self.shape = [6, 4, 5, 5]
         self.ksize = [3, 3]
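Note: for the backward pass, the grad functors gather rather than scatter; each
input position reads its gradient back from the output position its maximum was
routed to in the forward pass. A per-slice sketch under the same illustrative
naming as the forward sketch near the top (not the Paddle API):

    #include <vector>

    // `indices[i]` is where input element i landed in the forward pass, so its
    // gradient is read back from that same flat offset of the output gradient.
    void unpool2d_max_grad_slice(const std::vector<int>& indices,
                                 const std::vector<float>& output_grad,
                                 std::vector<float>* input_grad) {
      for (size_t i = 0; i < indices.size(); ++i) {
        (*input_grad)[i] = output_grad[indices[i]];
      }
    }

With T2 gone, the indices tensor is always read as int (indices.data<int>()),
which matches the test feeding Indices as int32.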