From 8284947b820e5b50e8af047bc40bc37b7b379830 Mon Sep 17 00:00:00 2001
From: whs <wanghaoshuang@baidu.com>
Date: Tue, 17 Jul 2018 22:31:56 +0800
Subject: [PATCH] Fix infershape of im2sequence. (#12183)

---
 paddle/fluid/operators/im2sequence_op.cc | 12 ++----------
 paddle/fluid/operators/im2sequence_op.h  |  5 +++--
 2 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/im2sequence_op.cc b/paddle/fluid/operators/im2sequence_op.cc
index c8c7f36536..8efd43928a 100644
--- a/paddle/fluid/operators/im2sequence_op.cc
+++ b/paddle/fluid/operators/im2sequence_op.cc
@@ -33,22 +33,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
 
     PADDLE_ENFORCE_EQ(in_dim.size(), 4,
                       "Input(X) format must be 4D tensor, eg., NCHW.");
-    int batch_size = in_dim[0];
     int img_channels = in_dim[1];
-    int img_height = in_dim[2];
-    int img_width = in_dim[3];
 
     auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
     auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
     auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
 
-    int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
-                                         paddings[2], strides[0]);
-    int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
-                                        paddings[3], strides[1]);
-
-    ctx->SetOutputDim("Out", {batch_size * output_height * output_width,
-                              img_channels * kernels[0] * kernels[1]});
+    ctx->SetOutputDim("Out",
+                      {in_dim[0], img_channels * kernels[0] * kernels[1]});
   }
 };
 
diff --git a/paddle/fluid/operators/im2sequence_op.h b/paddle/fluid/operators/im2sequence_op.h
index 5bfb91db18..4a99428194 100644
--- a/paddle/fluid/operators/im2sequence_op.h
+++ b/paddle/fluid/operators/im2sequence_op.h
@@ -109,12 +109,13 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
       }
       out->set_lod(lod);
     } else {
-      out->mutable_data<T>(ctx.GetPlace());
       int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
                                            paddings[2], strides[0]);
       int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
                                           paddings[3], strides[1]);
-
+      out->mutable_data<T>({batch_size * output_height * output_width,
+                            img_channels * kernels[0] * kernels[1]},
+                           ctx.GetPlace());
       const std::vector<int> dilations({1, 1});
       auto out_dims = out->dims();
       out->Resize({batch_size, out->numel() / batch_size});