From 4b06d8db9179c74a35582c85f782e8c268d361a6 Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Tue, 26 Sep 2017 20:06:12 +0800
Subject: [PATCH] fix globalPooling type (int => bool)

---
 paddle/operators/pool_op.cc | 47 ++++++++++++++++++++-----------------
 paddle/operators/pool_op.h  |  8 +++----
 2 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc
index a5e731cc66..9959b3ec07 100644
--- a/paddle/operators/pool_op.cc
+++ b/paddle/operators/pool_op.cc
@@ -35,7 +35,7 @@ class PoolOp : public framework::OperatorWithKernel {
 
     auto in_x = ctx.Input<Tensor>("X");
     auto out = ctx.Output<Tensor>("Out");
-    int global_pooling = Attr<int>("globalPooling");
+    bool global_pooling = Attr<bool>("globalPooling");
     std::string pooling_type = Attr<std::string>("poolingType");
     std::vector<int> ksize = Attr<std::vector<int>>("ksize");
     std::vector<int> strides = Attr<std::vector<int>>("strides");
@@ -45,6 +45,15 @@ class PoolOp : public framework::OperatorWithKernel {
                    "pooling_type should be 'max' or 'avg'");
     PADDLE_ENFORCE(in_x->dims().size() == 4 || in_x->dims().size() == 5,
                    "Pooling intput should be 4-D or 5-D");
+
+    if (global_pooling) {
+      ksize.resize(static_cast<size_t>(in_x->dims().size()) - 2);
+      for (size_t i = 0; i < ksize.size(); ++i)
+        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
+    }
+
+    PADDLE_ENFORCE(in_x->dims().size() == static_cast<size_t>(ksize.size() + 2),
+                   "Input size and Pooling size should be consistent.");
     PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
                    "Pooling size should be 2 elements. or 3 elements.");
     PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
@@ -52,12 +61,6 @@ class PoolOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
                       "paddings size and pooling size should be the same.");
 
-    if (global_pooling == 1) {
-      ksize.resize(static_cast<size_t>(in_x->dims().size()) - 2);
-      for (size_t i = 0; i < ksize.size(); ++i)
-        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
-    }
-
     std::vector<int64_t> output_shape({in_x->dims()[0], in_x->dims()[1]});
     for (size_t i = 0; i < ksize.size(); ++i) {
       output_shape.push_back(OutputSizePool(in_x->dims()[i + 2], ksize[i],
@@ -103,15 +106,16 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
                          "poolingType of pooling operator."
                          "str constant equal to 'max' or 'avg'");
     AddAttr<std::vector<int>>(
-        "ksize", "pooling size(height, width) of pooling operator.")
-        .AddCustomChecker(GreaterThanChecker_pool({0, 0}));
-    AddAttr<int>(
+        "ksize",
+        "Pooling size(depth, height, width) of pooling operator."
+        "If globalPooling = true, ksize is ignored and need not be specified.");
+    AddAttr<bool>(
         "globalPooling",
         "whether to use the globalPooling."
-        "int constant equal to 0 or 1"
-        "default 0"
-        "If globalPooling = 1, ksize is ignored and need not be specified.")
-        .SetDefault(0);
+        "int constant equal to false or true"
+        "default false"
+        "If globalPooling = true, ksize is ignored and need not be specified.")
+        .SetDefault(false);
     AddAttr<std::vector<int>>("strides",
                               "strides(height, width) of pooling operator."
                               "default {1,1}")
@@ -177,15 +181,16 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
                          "poolingType of pooling operator."
                          "str constant equal to 'max' or 'avg'");
     AddAttr<std::vector<int>>(
-        "ksize", "pooling size(depth, height, width) of pooling operator.")
-        .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
-    AddAttr<int>(
+        "ksize",
+        "pooling size(depth, height, width) of pooling operator."
+        "If globalPooling = true, ksize is ignored and need not be specified.");
+    AddAttr<bool>(
         "globalPooling",
         "whether to use the globalPooling."
-        "int constant equal to 0 or 1"
-        "default 0"
-        "If globalPooling = 1, ksize is ignored and need not be specified.")
-        .SetDefault(0);
+        "int constant equal to false or true"
+        "default false"
+        "If globalPooling = true, ksize is ignored and need not be specified.")
+        .SetDefault(false);
     AddAttr<std::vector<int>>(
         "strides",
         "strides(depth, height, width) of pooling operator."
diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h
index 9471282205..73c9721624 100644
--- a/paddle/operators/pool_op.h
+++ b/paddle/operators/pool_op.h
@@ -31,12 +31,12 @@ class PoolKernel : public framework::OpKernel {
     const Tensor* in_x = context.Input<Tensor>("X");
     Tensor* out = context.Output<Tensor>("Out");
 
-    int global_pooling = context.Attr<int>("globalPooling");
+    bool global_pooling = context.Attr<bool>("globalPooling");
     std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (global_pooling == 1) {
+    if (global_pooling) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
@@ -92,13 +92,13 @@ class PoolGradKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
 
-    int global_pooling = context.Attr<int>("globalPooling");
+    bool global_pooling = context.Attr<bool>("globalPooling");
     std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
 
-    if (global_pooling == 1) {
+    if (global_pooling) {
       for (size_t i = 0; i < ksize.size(); ++i)
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
     }