From 6fc9a9fd690e2d5fe48f2b39ed2575a04ef32103 Mon Sep 17 00:00:00 2001
From: sweetsky0901 <work@yq01-idl-gpu-online20.yq01.baidu.com>
Date: Tue, 28 Nov 2017 23:15:09 +0800
Subject: [PATCH] Remove the T2 template parameter from unpool and update the doc
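
The Indices tensor consumed by unpool is always int32, so the extra T2
template parameter on the unpooling functors and CUDA kernels is
redundant. The functors are now parameterized on the element type T
alone and read the indices as int internally. A minimal sketch of the
simplified call site after this change (it mirrors unpool_op.h; tensor
setup is elided):

    // Only the value type T varies per instantiation; indices are int32.
    math::Unpool2dMaxFunctor<platform::CPUPlace, float> unpool2d_max_forward;
    unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);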

---
 paddle/operators/math/unpooling.cc            | 20 +++++-----
 paddle/operators/math/unpooling.cu            | 39 +++++++++----------
 paddle/operators/math/unpooling.h             |  4 +-
 paddle/operators/unpool_op.cc                 | 19 +++++----
 paddle/operators/unpool_op.cu.cc              |  8 ++--
 paddle/operators/unpool_op.h                  |  8 ++--
 .../paddle/v2/fluid/tests/test_unpool_op.py   |  4 +-
 7 files changed, 52 insertions(+), 50 deletions(-)
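
Note on the doc update: the new op comment gives the unpool output shape
as H_out = (H_in - 1) * strides[0] - 2 * paddings[0] + ksize[0], with the
analogous expression for W_out. Below is a standalone sketch that checks
the formula numerically; the helper name and the stride/padding values
are illustrative only and are not taken from this patch:

    #include <cstdio>

    // Hypothetical helper: output extent of 2-D max-unpooling along one
    // spatial axis, per the formula in the updated op comment.
    int UnpoolOutSize(int in, int ksize, int stride, int padding) {
      return (in - 1) * stride - 2 * padding + ksize;
    }

    int main() {
      // e.g. H_in = 5, ksize = 3, stride = 2, padding = 0  ->  H_out = 11
      printf("H_out = %d\n", UnpoolOutSize(5, 3, 2, 0));
      return 0;
    }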

diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc
index ab6212f387..dbc3936971 100644
--- a/paddle/operators/math/unpooling.cc
+++ b/paddle/operators/math/unpooling.cc
@@ -19,8 +19,8 @@ namespace operators {
 namespace math {
 
 // All tensors are in NCHW format
-template <typename T, typename T2>
-class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -35,7 +35,7 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
     int input_feasize = input_height * input_width;
     int output_feasize = output_height * output_width;
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int* indices_data = indices.data<int>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
     for (int b = 0; b < batch_size; ++b) {
       for (int c = 0; c < output_channels; ++c) {
@@ -54,8 +54,8 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
 
 
 
-template <class T, typename T2>
-class Unpool2dMaxGradFunctor<platform::CPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -71,7 +71,7 @@ public:
     const int output_width = output.dims()[3];
     int input_feasize = input_height * input_width;
     int output_feasize = output_height * output_width;
-    const T2 * indices_data = indices.data<T2>();
+    const int* indices_data = indices.data<int>();
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
 
@@ -90,10 +90,10 @@ public:
   }
 };
 
-template class Unpool2dMaxGradFunctor<platform::CPUPlace, float, int>;
-template class Unpool2dMaxGradFunctor<platform::CPUPlace, double, int>;
-template class Unpool2dMaxFunctor<platform::CPUPlace, float, int>;
-template class Unpool2dMaxFunctor<platform::CPUPlace, double, int>;
+template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
+template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
+template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
+template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu
index 99e6fd052a..9cdd61f6d5 100644
--- a/paddle/operators/math/unpooling.cu
+++ b/paddle/operators/math/unpooling.cu
@@ -19,10 +19,10 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename T, typename T2>
+template <typename T>
 __global__ void KernelUnpool2dMax(const int nthreads,
                                   const T* input_data,
-                                  const T2 * indices_data,
+                                  const int* indices_data,
                                   const int input_height,
                                   const int input_width,
                                   const int channels,
@@ -45,10 +45,10 @@ __global__ void KernelUnpool2dMax(const int nthreads,
       output_data[out_offset + out_index] = input_data[i];
     }
 }
-template <typename T, typename T2>
+template <typename T>
 __global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                       const T* input_data,
-                                      const T2* indices_data,
+                                      const int* indices_data,
                                       const int input_height,
                                       const int input_width,
                                       const int channels,
@@ -76,8 +76,8 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
 /*
  * All tensors are in NCHW format.
  */
-template <typename T, typename T2>
-class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -90,15 +90,14 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
     const int output_height = output->dims()[2];
     const int output_width = output->dims()[3];
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int* indices_data = indices.data<int>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
-    int nthreads = batch_size * output_channels * input_height * input_width;
     int threads = 1024;
     int grid =  (input.numel() + threads - 1) / threads;
     KernelUnpool2dMax<
-        T, T2><<<grid, threads, 0,
+        T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(nthreads, input_data, indices_data,
+                 .stream()>>>(input.numel(), input_data, indices_data,
                               input_height, input_width, output_channels,
                               output_data, output_height, output_width);
   }
@@ -106,8 +105,8 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
 /*
  * All tensors are in NCHW format.
  */
-template <typename T, typename T2>
-class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
+template <typename T>
+class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
@@ -122,18 +121,16 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
     const int output_height = output.dims()[2];
     const int output_width = output.dims()[3];
     const T* input_data = input.data<T>();
-    const T2 * indices_data = indices.data<T2>();
+    const int* indices_data = indices.data<int>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
-    int nthreads = batch_size * output_channels * input_height * input_width;
     int threads = 1024;
     int grid =  (input.numel() + threads - 1) / threads;
     KernelUnpool2dMaxGrad<
-        T, T2><<<grid, threads, 0,
+        T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(
-                              nthreads, input_data, indices_data,
+                 .stream()>>>(input.numel(), input_data, indices_data,
                               input_height, input_width, output_channels,
                               output_data, output_grad_data,
                               output_height, output_width,
@@ -141,11 +138,11 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
   }
 };
 
-template class Unpool2dMaxGradFunctor<platform::GPUPlace, float, int>;
-template class Unpool2dMaxGradFunctor<platform::GPUPlace, double, int>;
+template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
+template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
 
-template class Unpool2dMaxFunctor<platform::GPUPlace, float, int>;
-template class Unpool2dMaxFunctor<platform::GPUPlace, double, int>;
+template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
+template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h
index e086b891a1..bf79354ed9 100644
--- a/paddle/operators/math/unpooling.h
+++ b/paddle/operators/math/unpooling.h
@@ -19,7 +19,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 
 class Unpool2dMaxFunctor {
  public:
@@ -29,7 +29,7 @@ class Unpool2dMaxFunctor {
                   framework::Tensor * output);
 };
 
-template <typename Place, class T, typename T2>
+template <typename Place, typename T>
 class Unpool2dMaxGradFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc
index 49a5129188..2505148764 100644
--- a/paddle/operators/unpool_op.cc
+++ b/paddle/operators/unpool_op.cc
@@ -50,10 +50,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
         "(string), unpooling type, can be \"max\" for max-unpooling ")
         .InEnum({"max"});
     AddComment(R"DOC(
-          "Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
+          "Input shape: $(N, C_{in}, H_{in}, W_{in})$
+          Output shape: $(N, C_{out}, H_{out}, W_{out})$
+          Where
+          $$
+            H_{out} = (H_{in}-1) * strides[0] - 2 * paddings[0] + ksize[0] \\
+            W_{out} = (W_{in}-1) * strides[1] - 2 * paddings[1] + ksize[1]
+          $$
+          Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
           /07/iccv2011.pdf
-          PyTorch: http://pytorch.org/docs/master/nn.html?highlight=unpool#
-          torch.nn.MaxUnpool2d"
         )DOC");
   }
 };
@@ -125,9 +130,9 @@ namespace ops = paddle::operators;
 REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
             ops::UnpoolOpGrad);
 REGISTER_OP_CPU_KERNEL(unpool,
-              ops::UnpoolKernel<paddle::platform::CPUPlace, float, int>,
-              ops::UnpoolKernel<paddle::platform::CPUPlace, double, int>);
+              ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
+              ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(unpool_grad,
-            ops::UnpoolGradKernel<paddle::platform::CPUPlace, float, int>,
-            ops::UnpoolGradKernel<paddle::platform::CPUPlace, double, int>);
+            ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
+            ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
 
diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc
index 9b5ac667d3..d8214fc687 100644
--- a/paddle/operators/unpool_op.cu.cc
+++ b/paddle/operators/unpool_op.cu.cc
@@ -16,10 +16,10 @@ limitations under the License. */
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(unpool,
-                ops::UnpoolKernel<paddle::platform::GPUPlace, float, int>,
-                ops::UnpoolKernel<paddle::platform::GPUPlace, double, int>);
+                ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
+                ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(unpool_grad,
                        ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                        float, int>,
+                        float>,
                        ops::UnpoolGradKernel<paddle::platform::GPUPlace,
-                        double, int>);
+                        double>);
diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h
index dfd4ef12b5..f618a7c0ba 100644
--- a/paddle/operators/unpool_op.h
+++ b/paddle/operators/unpool_op.h
@@ -21,7 +21,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class UnpoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -37,12 +37,12 @@ class UnpoolKernel : public framework::OpKernel<T> {
       math::SetConstant<Place, T> set_zero;
       set_zero(context.device_context(), out, static_cast<T>(0));
     }
-    math::Unpool2dMaxFunctor<Place, T, T2> unpool2d_max_forward;
+    math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
     unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
   }
 };
 
-template <typename Place, typename T, typename T2>
+template <typename Place, typename T>
 class UnpoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -64,7 +64,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
       in_x_grad->mutable_data<T>(context.GetPlace());
       zero(device_ctx, in_x_grad, static_cast<T>(0));
     }
-    math::Unpool2dMaxGradFunctor<Place, T, T2> unpool2d_max_backward;
+    math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
     unpool2d_max_backward(context.device_context(), *in_x, *in_y,
                           *out, *out_grad, in_x_grad);
   }
diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py
index b3c6c85025..292b9bc14a 100644
--- a/python/paddle/v2/fluid/tests/test_unpool_op.py
+++ b/python/paddle/v2/fluid/tests/test_unpool_op.py
@@ -50,7 +50,7 @@ class TestUnpoolOp(OpTest):
                         indices[nidx, cidx, i, j] = \
                                 (r_start + arg / self.ksize[1]) * wsize + \
                                 c_start + arg % self.ksize[1]
-        output = self.Unpool2d_forward_naive(input, indices, self.ksize, \
+        output = self.unpool2d_forward_naive(input, indices, self.ksize, \
                 self.strides, self.paddings).astype("float32")
         self.inputs = {'X': input.astype('float32'),
                        'Indices': indices.astype('int32')}
@@ -69,7 +69,7 @@ class TestUnpoolOp(OpTest):
         self.check_grad(['X'], 'Out')
 
     def init_test_case(self):
-        self.Unpool2d_forward_naive = unpool2dmax_forward_naive
+        self.unpool2d_forward_naive = unpool2dmax_forward_naive
         self.unpooling_type = "max"
         self.shape = [6, 4, 5, 5]
         self.ksize = [3, 3]