From 4b4af84e677da837cc809a10be41517c401f465a Mon Sep 17 00:00:00 2001
From: sneaxiy <sneaxiy@126.com>
Date: Tue, 16 Oct 2018 07:09:23 +0000
Subject: [PATCH] test=develop

---
 paddle/fluid/API.spec                         |   1 +
 paddle/fluid/operators/math/algorithm.h       |  46 ++++++
 paddle/fluid/operators/sequence_reverse_op.cc |  29 ++++
 paddle/fluid/operators/sequence_reverse_op.cu |  25 +++
 paddle/fluid/operators/sequence_reverse_op.h  | 155 ++++++++++++++++++
 python/paddle/fluid/layers/nn.py              |  29 ++++
 .../tests/unittests/test_sequence_reverse.py  |  69 ++++++++
 7 files changed, 354 insertions(+)
 create mode 100644 paddle/fluid/operators/sequence_reverse_op.cc
 create mode 100644 paddle/fluid/operators/sequence_reverse_op.cu
 create mode 100644 paddle/fluid/operators/sequence_reverse_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_sequence_reverse.py

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 212724a0c7..2d34902e10 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -171,6 +171,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
 paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
 paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h
index 262469beea..2e75b6abce 100644
--- a/paddle/fluid/operators/math/algorithm.h
+++ b/paddle/fluid/operators/math/algorithm.h
@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
   return -1;
 }
 
+template <typename T>
+HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
+#ifdef __CUDA_ARCH__
+  // The following code is from
+  // https://en.cppreference.com/w/cpp/algorithm/lower_bound
+  auto *first = x;
+  int64_t count = static_cast<int64_t>(num);
+  while (count > 0) {
+    int64_t step = (count >> 1);
+    auto *it = first + step;
+    if (*it < val) {
+      first = ++it;
+      count -= (step + 1);
+    } else {
+      count = step;
+    }
+  }
+  return static_cast<size_t>(first - x);
+#else
+  return static_cast<size_t>(std::lower_bound(x, x + num, val) - x);
+#endif
+}
+
+template <typename T>
+HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
+#ifdef __CUDA_ARCH__
+  // The following code is from
+  // https://en.cppreference.com/w/cpp/algorithm/upper_bound
+  auto *first = x;
+  int64_t count = static_cast<int64_t>(num);
+  while (count > 0) {
+    auto step = (count >> 1);
+    auto *it = first + step;
+    if (val < *it) {
+      count = step;
+    } else {
+      first = ++it;
+      count -= (step + 1);
+    }
+  }
+  return static_cast<size_t>(first - x);
+#else
+  return static_cast<size_t>(std::upper_bound(x, x + num, val) - x);
+#endif
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/sequence_reverse_op.cc b/paddle/fluid/operators/sequence_reverse_op.cc
new file mode 100644
index 0000000000..1428cca1a6
--- /dev/null
+++ b/paddle/fluid/operators/sequence_reverse_op.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_reverse_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(sequence_reverse, ops::SequenceReverseOp,
+                  ops::SequenceReverseOpMaker,
+                  ops::SequenceReverseGradOpDescMaker);
+
+REGISTER_OP_CPU_KERNEL(
+    sequence_reverse,
+    ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
+    ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/sequence_reverse_op.cu b/paddle/fluid/operators/sequence_reverse_op.cu
new file mode 100644
index 0000000000..ce65f4799e
--- /dev/null
+++ b/paddle/fluid/operators/sequence_reverse_op.cu
@@ -0,0 +1,25 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_reverse_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reverse,
+    ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
+    ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/sequence_reverse_op.h b/paddle/fluid/operators/sequence_reverse_op.h
new file mode 100644
index 0000000000..ec11a548c5
--- /dev/null
+++ b/paddle/fluid/operators/sequence_reverse_op.h
@@ -0,0 +1,155 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/algorithm.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceReverseOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
+
+    auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GE(x_dim.size(), 2,
+                      "Rank of Input(X) must be not less than 2.");
+
+    ctx->SetOutputDim("Y", x_dim);
+    ctx->ShareLoD("X", "Y");
+  }
+};
+
+class SequenceReverseOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input LoDTensor of sequence_reverse op.");
+    AddOutput("Y", "The output LoDTensor of sequence_reverse op.");
+    AddComment(R"DOC(
+SequenceReverse Operator.
+
+Reverse each sequence in input X along dim 0.
+
+Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where:
+
+X.data() = [
+  [1, 2, 3, 4],
+  [5, 6, 7, 8], # the 0-th sequence with length 2
+  [9, 10, 11, 12],
+  [13, 14, 15, 16],
+  [17, 18, 19, 20] # the 1-st sequence with length 3
+]
+
+The output Y would be a LoDTensor sharing the same dims and lod with input X,
+and:
+
+Y.data() = [
+  [5, 6, 7, 8],
+  [1, 2, 3, 4], # the reversed 0-th sequence with length 2
+  [17, 18, 19, 20],
+  [13, 14, 15, 16],
+  [9, 10, 11, 12] # the reversed 1-st sequence with length 3
+]
+
+This Operator is useful to build a reverse dynamic RNN network.
+    )DOC");
+  }
+};
+
+template <typename T>
+struct SequenceReverseFunctor {
+  SequenceReverseFunctor(const T *x, T *y, const size_t *lod, size_t lod_count,
+                         size_t row_numel)
+      : x_(x), y_(y), lod_(lod), lod_count_(lod_count), row_numel_(row_numel) {}
+
+  HOSTDEVICE void operator()(size_t idx_x) const {
+    auto row_idx_x = idx_x / row_numel_;
+    auto lod_idx = math::UpperBound(lod_, lod_count_, row_idx_x);
+    auto row_idx_y = lod_[lod_idx - 1] + (lod_[lod_idx] - 1 - row_idx_x);
+    auto idx_y = row_idx_y * row_numel_ + idx_x % row_numel_;
+    y_[idx_y] = x_[idx_x];
+  }
+
+  const T *x_;
+  T *y_;
+  const size_t *lod_;
+  size_t lod_count_;
+  size_t row_numel_;
+};
+
+template <typename DeviceContext, typename T>
+class SequenceReverseOpKernel : public framework::OpKernel<T> {
+  using LoDTensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto &x = *ctx.Input<LoDTensor>("X");
+    auto *y = ctx.Output<LoDTensor>("Y");
+
+    PADDLE_ENFORCE_EQ(x.lod().size(), 1,
+                      "SequenceReverse Op only support one level lod.");
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    const size_t *lod;
+    size_t lod_count = x.lod()[0].size();
+
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      lod = x.lod()[0].CUDAData(ctx.GetPlace());
+    } else {
+#endif
+      lod = x.lod()[0].data();
+#ifdef PADDLE_WITH_CUDA
+    }
+#endif
+
+    size_t limit = static_cast<size_t>(x.numel());
+    size_t row_numel = static_cast<size_t>(limit / x.dims()[0]);
+    auto *x_data = x.data<T>();
+    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
+
+    PADDLE_ENFORCE_NE(x_data, y_data,
+                      "SequenceReverse Op does not support in-place operation");
+
+    SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
+                                      row_numel);
+    platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+    for_range(functor);
+  }
+};
+
+class SequenceReverseGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_reverse");
+    op->SetInput("X", OutputGrad("Y"));
+    op->SetOutput("Y", InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 43aa4a9e7c..aaeb9b666e 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -151,6 +151,7 @@ __all__ = [
     'mul',
     'sigmoid_cross_entropy_with_logits',
     'maxout',
+    'sequence_reverse',
 ]
 
 
@@ -7134,3 +7135,31 @@ def maxout(x, groups, name=None):
         attrs={"groups": groups},
         outputs={"Out": out})
     return out
+
+
+@templatedoc()
+def sequence_reverse(x, name=None):
+    """ 
+    ${comment}
+
+    Args:
+        x(${x_type}): ${x_comment}
+        name(basestring|None): Name of the output.
+
+    Returns:
+        out(${y_type}): ${y_comment}
+    """
+    helper = LayerHelper("sequence_reverse", **locals())
+
+    if name is None:
+        out = helper.create_tmp_variable(dtype=x.dtype)
+    else:
+        out = helper.create_variable(
+            name=name, dtype=x.dtype, persistable=False)
+
+    helper.append_op(
+        type="sequence_reverse",
+        inputs={"X": x},
+        outputs={"Y": out},
+        attrs=dict())
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_sequence_reverse.py b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py
new file mode 100644
index 0000000000..eebd25e097
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from op_test import OpTest
+import numpy as np
+
+
+class TestSequenceReverseBase(OpTest):
+    def initParameters(self):
+        pass
+
+    def setUp(self):
+        self.size = (10, 3, 4)
+        self.lod = [2, 3, 5]
+        self.dtype = 'float32'
+        self.initParameters()
+        self.op_type = 'sequence_reverse'
+        self.x = np.random.random(self.size).astype(self.dtype)
+        self.y = self.get_output()
+
+        self.inputs = {'X': (self.x, [self.lod, ]), }
+        self.outputs = {'Y': (self.y, [self.lod, ]), }
+
+    def get_output(self):
+        tmp_x = np.reshape(self.x, newshape=[self.x.shape[0], -1])
+        tmp_y = np.ndarray(tmp_x.shape).astype(self.dtype)
+        prev_idx = 0
+        for cur_len in self.lod:
+            idx_range = range(prev_idx, prev_idx + cur_len)
+            tmp_y[idx_range, :] = np.flip(tmp_x[idx_range, :], 0)
+            prev_idx += cur_len
+
+        return np.reshape(tmp_y, newshape=self.x.shape).astype(self.dtype)
+
+    def test_output(self):
+        self.check_output(0)
+
+    def test_grad(self):
+        self.check_grad(['X'], 'Y')
+
+
+class TestSequenceReserve1(TestSequenceReverseBase):
+    def initParameters(self):
+        self.size = (12, 10)
+        self.lod = [4, 5, 3]
+
+
+class TestSequenceReverse2(TestSequenceReverseBase):
+    def initParameters(self):
+        self.size = (12, 10)
+        self.lod = [12]
+
+
+if __name__ == '__main__':
+    unittest.main()