From 9621213230c9caeac216f4796473f257e5065ec1 Mon Sep 17 00:00:00 2001
From: xzl <zlx_hg@163.com>
Date: Wed, 18 Oct 2017 17:41:25 +0800
Subject: [PATCH] add max-pool-with-mask c++ impl

---
 paddle/gserver/layers/PoolLayer.cpp           |  9 +++--
 paddle/gserver/layers/PoolLayer.h             |  2 ++
 paddle/gserver/layers/PoolProjection.cpp      | 36 ++++++++++++++++++-
 paddle/gserver/layers/PoolProjection.h        | 13 ++++++-
 paddle/gserver/layers/PoolProjectionLayer.cpp | 10 +++++-
 paddle/gserver/layers/Projection.h            | 13 +++++++
 6 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/paddle/gserver/layers/PoolLayer.cpp b/paddle/gserver/layers/PoolLayer.cpp
index 7b932d5a76..c5f4143a5b 100644
--- a/paddle/gserver/layers/PoolLayer.cpp
+++ b/paddle/gserver/layers/PoolLayer.cpp
@@ -44,14 +44,19 @@ bool PoolLayer::init(const LayerMap& layerMap,
   strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
   confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
   outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
-
+  with_mask_ = false;
+  if (poolType_ == "max-pool-with-mask") {
+    setOutput("mask", &mask_);
+    with_mask_ = true;
+  }
   return true;
 }
 
 Layer* PoolLayer::create(const LayerConfig& config) {
   CHECK_EQ(config.inputs_size(), 1);
   const std::string& pool = config.inputs(0).pool_conf().pool_type();
-  if (pool == "max-projection" || pool == "avg-projection") {
+  if (pool == "max-projection" || pool == "avg-projection" ||
+      pool == "max-pool-with-mask") {
     return new PoolProjectionLayer(config);
 #ifdef PADDLE_WITH_CUDA
   } else if (CudnnPoolLayer::typeCheck(pool)) {
diff --git a/paddle/gserver/layers/PoolLayer.h b/paddle/gserver/layers/PoolLayer.h
index d43292ad2d..780bfd0bce 100644
--- a/paddle/gserver/layers/PoolLayer.h
+++ b/paddle/gserver/layers/PoolLayer.h
@@ -37,6 +37,8 @@ protected:
   int confPaddingY_;
 
   std::string poolType_;
+  bool with_mask_;
+  Argument mask_;
 
 public:
   explicit PoolLayer(const LayerConfig& config) : Layer(config) {}
diff --git a/paddle/gserver/layers/PoolProjection.cpp b/paddle/gserver/layers/PoolProjection.cpp
index d90b438448..ccf58228a7 100644
--- a/paddle/gserver/layers/PoolProjection.cpp
+++ b/paddle/gserver/layers/PoolProjection.cpp
@@ -36,6 +36,10 @@ PoolProjection::PoolProjection(const ProjectionConfig& config,
   strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
   confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
   outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
+  with_mask_ = false;
+  if (poolType_ == "max-pool-with-mask") {
+    with_mask_ = true;
+  }
 }
 
 size_t PoolProjection::getSize() {
@@ -73,6 +77,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
     return new MaxPoolProjection(config, parameter, useGpu);
   } else if (pool == "avg-projection") {
     return new AvgPoolProjection(config, parameter, useGpu);
+  } else if (pool == "max-pool-with-mask") {
+    return new MaxPoolProjection(config, parameter, useGpu);
   } else {
     LOG(FATAL) << "Unknown pool type: " << pool;
     return nullptr;
@@ -84,6 +90,10 @@ void MaxPoolProjection::forward() {
   CHECK_EQ(width, out_->value->getWidth());
   MatrixPtr inputV = in_->value;
   MatrixPtr outV = out_->value;
+  MatrixPtr maskV = out_->value;
+  if (with_mask_) {
+    maskV = mask_->value;
+  }
   outV->maxPoolForward(*inputV,
                        imgSizeY_,
                        imgSize_,
@@ -95,7 +105,9 @@ void MaxPoolProjection::forward() {
                        outputY_,
                        outputX_,
                        confPaddingY_,
-                       confPadding_);
+                       confPadding_,
+                       maskV,
+                       with_mask_);
 }
 
 void MaxPoolProjection::backward(const UpdateCallback& callback) {
@@ -168,4 +180,26 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) {
                              confPaddingY_,
                              confPadding_);
 }
+
+void MaxWithMaskPoolProjection::forward() {
+  size_t width = getSize();
+  CHECK_EQ(width, out_->value->getWidth());
+  MatrixPtr inputV = in_->value;
+  MatrixPtr outV = out_->value;
+  MatrixPtr maskV = mask_->value;
+  outV->maxPoolForward(*inputV,
+                       imgSizeY_,
+                       imgSize_,
+                       channels_,
+                       sizeX_,
+                       sizeY_,
+                       strideY_,
+                       stride_,
+                       outputY_,
+                       outputX_,
+                       confPaddingY_,
+                       confPadding_,
+                       maskV,
+                       with_mask_);
+}
 }  // namespace paddle
diff --git a/paddle/gserver/layers/PoolProjection.h b/paddle/gserver/layers/PoolProjection.h
index 9a75f465f6..d240d5c87e 100644
--- a/paddle/gserver/layers/PoolProjection.h
+++ b/paddle/gserver/layers/PoolProjection.h
@@ -28,6 +28,7 @@ protected:
   int confPaddingY_, confPadding_;
   size_t channels_;
   std::string poolType_;
+  bool with_mask_;
 
 public:
   PoolProjection(const ProjectionConfig& config,
@@ -37,7 +38,6 @@ public:
   static PoolProjection* create(const ProjectionConfig& config,
                                 ParameterPtr parameter,
                                 bool useGpu);
-
   const std::string& getPoolType() const { return poolType_; }
 
   size_t getSize();
@@ -64,4 +64,15 @@ public:
   virtual void forward();
   virtual void backward(const UpdateCallback& callback = nullptr);
 };
+
+class MaxWithMaskPoolProjection : public MaxPoolProjection {
+public:
+  MaxWithMaskPoolProjection(const ProjectionConfig& config,
+                            ParameterPtr parameter,
+                            bool useGpu)
+      : MaxPoolProjection(config, parameter, useGpu) {}
+
+  virtual void forward();
+};
+
 }  // namespace paddle
diff --git a/paddle/gserver/layers/PoolProjectionLayer.cpp b/paddle/gserver/layers/PoolProjectionLayer.cpp
index ed5011ab89..5cd61a9ea8 100644
--- a/paddle/gserver/layers/PoolProjectionLayer.cpp
+++ b/paddle/gserver/layers/PoolProjectionLayer.cpp
@@ -51,8 +51,16 @@ void PoolProjectionLayer::forward(PassType passType) {
   const Argument& in = getInput(0);
   int batchSize = in.value->getHeight();
   int size = getSize();
+
+  if (with_mask_) {
+    resetSpecifyOutput(mask_,
+                       batchSize,
+                       size,
+                       /* isValueClean */ false,
+                       /* isGradClean */ true);
+  }
   resetOutput(batchSize, size);
-  poolProjection_->forward(&in, &output_, passType);
+  poolProjection_->forward(&in, &output_, &mask_, passType);
 }
 
 void PoolProjectionLayer::backward(const UpdateCallback& callback) {
diff --git a/paddle/gserver/layers/Projection.h b/paddle/gserver/layers/Projection.h
index 778a7fe13d..f60a9b931b 100644
--- a/paddle/gserver/layers/Projection.h
+++ b/paddle/gserver/layers/Projection.h
@@ -69,6 +69,17 @@ public:
     forward();
   }
 
+  void forward(const Argument* in,
+               const Argument* out,
+               const Argument* mask,
+               PassType passType) {
+    in_ = in;
+    out_ = out;
+    mask_ = mask;
+    passType_ = passType;
+    forward();
+  }
+
   virtual void prefetch(const Argument* in) {}
   virtual void forward() = 0;
   virtual void backward(const UpdateCallback& callback) = 0;
@@ -130,6 +141,8 @@ protected:
   const Argument* in_;
   /// Store `out` passed to forward()
   const Argument* out_;
+  /// Store `mask` passed to forward()
+  const Argument* mask_;
   /// Store `passType` passed to forward()
   PassType passType_;
   /// Layer forward function