From 6be273cbdbf3fe46fe5a4b4af787977b9bd59929 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Fri, 24 Aug 2018 22:40:38 +0800
Subject: [PATCH 1/5] add seq mode lstm

---
 paddle/fluid/operators/fusion_lstm_op.cc | 52 ++++++++++++++++++++----
 1 file changed, 45 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 3888333ec5..870292827d 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -15,10 +15,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/fusion_lstm_op.h"
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/operators/math/lstm_compute.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+DEFINE_bool(seq_mode, true, "Use sequence mode");
 
 namespace paddle {
 namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
 
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  int xx_width;
+  if (FLAGS_seq_mode) {
+    xx_width = wx_dims[1];
+  } else {
+    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
 }
@@ -205,10 +214,34 @@ inline void ReorderInitState(const DeviceContext& ctx,
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class FuisonLSTMKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void SeqCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* wx = ctx.Input<Tensor>("WeightX");
+    auto* wh = ctx.Input<Tensor>("WeightH");
+    auto* bias = ctx.Input<Tensor>("Bias");
+
+    auto* xx = ctx.Output<LoDTensor>("XX");
+
+    auto x_dims = x->dims();    // T x M
+    auto wh_dims = wh->dims();  // D x 4D
+    const int M = x_dims[1];    // x frame size
+    const int D4 = wh_dims[1];
+
+    const T* x_data = x->data<T>();
+    const T* wx_data = wx->data<T>();
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
+                                      xx_data, bias->data<T>());
+  }
+
+  void BatchCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
     auto* wx = ctx.Input<Tensor>("WeightX");
     auto* wh = ctx.Input<Tensor>("WeightH");
@@ -339,6 +372,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     // restore the output cell state in LoDTensor from the batch cell
     to_seq(dev_ctx, batch_cell, cell_out);
   }
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    if (FLAGS_seq_mode) {
+      SeqCompute(ctx);
+    } else {
+      BatchCompute(ctx);
+    }
+  }
 };
 
 }  // namespace operators
@@ -348,7 +388,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 
-REGISTER_OP_CPU_KERNEL(
-    fusion_lstm,
-    ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
+                       ops::FuisonLSTMKernel<double>);

From 607c41952e78d8c5d489a75590204f802d392ee5 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sun, 26 Aug 2018 16:10:45 +0800
Subject: [PATCH 2/5] compute gates

---
 paddle/fluid/operators/fusion_lstm_op.cc | 87 +++++++++++++++++++++++-
 1 file changed, 84 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 870292827d..604c6f1839 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -220,24 +220,105 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
     auto* wx = ctx.Input<Tensor>("WeightX");
     auto* wh = ctx.Input<Tensor>("WeightH");
     auto* bias = ctx.Input<Tensor>("Bias");
 
     auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
 
-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 4D
-    const int M = x_dims[1];    // x frame size
+    auto x_lod = x->lod();
+    auto x_dims = x->dims();            // T x M
+    auto wh_dims = wh->dims();          // D x 4D
+    const int N = x_lod[0].size() - 1;  // batch size
+    const int M = x_dims[1];            // x frame size
+    const int D = wh_dims[0];
+    const int D2 = D * 2;
+    const int D3 = D * 3;
     const int D4 = wh_dims[1];
 
     const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* c0_data = c0 ? c0->data<T>() : NULL;
     const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
     T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
 
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
                                       xx_data, bias->data<T>());
+
+    for (int i = 0; i < N; ++i) {
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
+      const T* prev_cell_data = NULL;
+      const T* prev_hidden_data = NULL;
+      int tstart = 0;
+      if (h0_data) {
+        prev_hidden_data = h0_data + i * D;
+        prev_cell_data = c0_data + i * D;
+      } else {
+        // W_ch, W_ih, W_fh, W_oh
+        // actgate
+        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
+        // ch gate
+        math::vec_tanh<T>(D, xx_data, xx_data);
+        // cell out= input*tilde
+        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
+        // hidden out= act_state(cellout) * outgate
+        // act state
+        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
+
+        // prev
+        prev_hidden_data = hidden_out_data;
+        prev_cell_data = cell_out_data;
+        tstart = 1;
+
+        // move offset
+        xx_data = xx_data + D4;
+        hidden_out_data = hidden_out_data + D;
+        cell_out_data = cell_out_data + D;
+      }
+      for (int step = tstart; step < seq_len; ++step) {
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
+                  prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
+                  D4);
+
+        // W_ch, W_ih, W_fh, W_oh
+        // actgate
+        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
+        // ch gate
+        math::vec_tanh<T>(D, xx_data, xx_data);
+
+        // a = forget * prev_cell
+        blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
+
+        // b = input * tilde
+        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
+
+        // cell out= a+b
+        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
+
+        // hidden out= act_state(cellout) * outgate
+        // act state
+        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
+
+        // prev
+        prev_hidden_data = hidden_out_data;
+        prev_cell_data = cell_out_data;
+
+        // move offset
+        xx_data = xx_data + D4;
+        hidden_out_data = hidden_out_data + D;
+        cell_out_data = cell_out_data + D;
+      }
+    }
   }
 
   void BatchCompute(const framework::ExecutionContext& ctx) const {

From 4b28fab8c94863d5ff24ce4c59ff31bb5d06b4ee Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sun, 26 Aug 2018 18:24:00 +0800
Subject: [PATCH 3/5] enable more acts

---
 paddle/fluid/operators/fusion_lstm_op.cc      | 34 ++++++++++++-------
 .../tests/unittests/test_fusion_lstm_op.py    |  2 +-
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 604c6f1839..97852e2928 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
     auto* cell_out = ctx.Output<LoDTensor>("Cell");
 
+    std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
+    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
+    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
+    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    if (platform::jit::MayIUse(platform::jit::avx)) {
+      math::VecActivations<T, platform::jit::avx> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    } else {
+      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    }
+
     auto x_lod = x->lod();
     auto x_dims = x->dims();            // T x M
     auto wh_dims = wh->dims();          // D x 4D
@@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         prev_cell_data = c0_data + i * D;
       } else {
         // W_ch, W_ih, W_fh, W_oh
-        // actgate
-        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
-        // ch gate
-        math::vec_tanh<T>(D, xx_data, xx_data);
+        act_gate(D3, xx_data + D, xx_data + D);
+        act_cand(D, xx_data, xx_data);
         // cell out= input*tilde
         blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
         // hidden out= act_state(cellout) * outgate
-        // act state
-        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        act_cell(D, cell_out_data, xx_data + D2);
         blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
 
         // prev
@@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
                   D4);
 
         // W_ch, W_ih, W_fh, W_oh
-        // actgate
-        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
-        // ch gate
-        math::vec_tanh<T>(D, xx_data, xx_data);
+        act_gate(D3, xx_data + D, xx_data + D);
+        act_cand(D, xx_data, xx_data);
 
         // a = forget * prev_cell
         blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
@@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
 
         // hidden out= act_state(cellout) * outgate
-        // act state
-        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        act_cell(D, cell_out_data, xx_data + D2);
         blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
 
         // prev
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
index 9d8bef677f..d807f0a8b6 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
@@ -45,7 +45,7 @@ def fusion_lstm(
 
 class TestLstmOp(OpTest):
     def set_argument(self):
-        self.lod = [[2, 3, 2]]
+        pass
 
     def setUp(self):
         self.op_type = 'fusion_lstm'

From 1777cd09f652e18c85a5017058cd29c4794446fa Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sun, 26 Aug 2018 18:42:20 +0800
Subject: [PATCH 4/5] refine fusion lstm op test

---
 .../tests/unittests/test_fusion_lstm_op.py    | 61 +++++++++++--------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
index d807f0a8b6..19f22fc7bd 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
@@ -43,13 +43,13 @@ def fusion_lstm(
         act_cell, act_cand)
 
 
-class TestLstmOp(OpTest):
-    def set_argument(self):
+class TestFusionLSTMOp(OpTest):
+    def set_conf(self):
         pass
 
     def setUp(self):
         self.op_type = 'fusion_lstm'
-        self.lod = [[2, 3, 2]]
+        self.lod = [[2, 3, 5, 4]]
         self.M = 8
         self.D = 16
         self.has_initial_state = False
@@ -58,33 +58,33 @@ class TestLstmOp(OpTest):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
         self.use_peepholes = False
-        self.set_argument()
+        self.set_conf()
 
         T = sum(self.lod[0])
         bs = len(self.lod[0])
 
-        x = np.random.normal(size=(T, self.M)).astype('float64')
+        x = np.random.normal(size=(T, self.M)).astype('float32')
         if self.has_initial_state:
-            h0 = np.random.normal(size=(bs, self.D)).astype('float64')
-            c0 = np.random.normal(size=(bs, self.D)).astype('float64')
+            h0 = np.random.normal(size=(bs, self.D)).astype('float32')
+            c0 = np.random.normal(size=(bs, self.D)).astype('float32')
         else:
-            h0 = np.zeros((bs, self.D)).astype('float64')
-            c0 = np.zeros((bs, self.D)).astype('float64')
+            h0 = np.zeros((bs, self.D)).astype('float32')
+            c0 = np.zeros((bs, self.D)).astype('float32')
 
-        wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
+        wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
 
         if self.use_peepholes:
-            b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+            b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
         else:
-            b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
+            b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
         w_b = np.copy(b[:, 0:4 * self.D])
         w_c = b[:, 4 * self.D:] if self.use_peepholes else None
 
         # this is the weight of fc
-        wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
+        wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
         # this is the bias of fc
         # and it should be manually added into the bias of this fusion LSTM
-        bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
+        bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
         b[0, 0:4 * self.D] += bx[0, :]
         h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
                            self.is_reverse, ACTIVATION[self.act_gate],
@@ -114,35 +114,44 @@ class TestLstmOp(OpTest):
         }
 
     def test_check_output(self):
-        self.check_output(atol=1e-8)
+        self.check_output()
 
 
-class TestLstmOpInitReverse(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpInit(TestFusionLSTMOp):
+    def set_conf(self):
         self.has_initial_state = True
-        self.is_reverse = True
 
 
-class TestLstmOpMD1(TestLstmOp):
-    def set_argument(self):
+# class TestFusionLSTMOpReverse(TestFusionLSTMOp):
+#     def set_conf(self):
+#         self.is_reverse = True
+
+# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
+#     def set_conf(self):
+#         self.has_initial_state = True
+#         self.is_reverse = True
+
+
+class TestFusionLSTMOpMD1(TestFusionLSTMOp):
+    def set_conf(self):
         self.M = 36
         self.D = 8
 
 
-class TestLstmOpMD2(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpMD2(TestFusionLSTMOp):
+    def set_conf(self):
         self.M = 8
         self.D = 8
 
 
-class TestLstmOpMD3(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpMD3(TestFusionLSTMOp):
+    def set_conf(self):
         self.M = 15
         self.D = 3
 
 
-class TestLstmOpBS1(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpBS1(TestFusionLSTMOp):
+    def set_conf(self):
         self.lod = [[3]]
         self.D = 16
 

From e61cf3214da019ca1de1fb68ae143928877b4e62 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sun, 26 Aug 2018 21:00:56 +0800
Subject: [PATCH 5/5] complete reverse seq

---
 paddle/fluid/operators/fusion_lstm_op.cc      | 41 ++++++++++++-------
 .../tests/unittests/test_fusion_lstm_op.py    | 17 ++++----
 2 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 97852e2928..e4e4ac8e33 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -229,6 +229,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     auto* xx = ctx.Output<LoDTensor>("XX");
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
     auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
 
     std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
     auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
@@ -247,8 +248,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     }
 
     auto x_lod = x->lod();
-    auto x_dims = x->dims();            // T x M
-    auto wh_dims = wh->dims();          // D x 4D
+    auto x_dims = x->dims();    // T x M
+    auto wh_dims = wh->dims();  // D x 4D
+    const int total_T = x_dims[0];
     const int N = x_lod[0].size() - 1;  // batch size
     const int M = x_dims[1];            // x frame size
     const int D = wh_dims[0];
@@ -266,17 +268,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
 
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
+    math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
                                       xx_data, bias->data<T>());
+    int xx_offset = D4;
+    int gate_offset = D;
+    if (is_reverse) {
+      const int offset = (total_T - 1) * D;
+      xx_data = xx_data + offset * 4;
+      hidden_out_data = hidden_out_data + offset;
+      cell_out_data = cell_out_data + offset;
+      xx_offset = -D4;
+      gate_offset = -D;
+    }
+
+    auto move_step = [&]() {
+      xx_data = xx_data + xx_offset;
+      hidden_out_data = hidden_out_data + gate_offset;
+      cell_out_data = cell_out_data + gate_offset;
+    };
 
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
+      int bid = is_reverse ? N - 1 - i : i;
+      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
       const T* prev_cell_data = NULL;
       const T* prev_hidden_data = NULL;
       int tstart = 0;
       if (h0_data) {
-        prev_hidden_data = h0_data + i * D;
-        prev_cell_data = c0_data + i * D;
+        prev_hidden_data = h0_data + bid * D;
+        prev_cell_data = c0_data + bid * D;
       } else {
         // W_ch, W_ih, W_fh, W_oh
         act_gate(D3, xx_data + D, xx_data + D);
@@ -292,10 +311,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         prev_cell_data = cell_out_data;
         tstart = 1;
 
-        // move offset
-        xx_data = xx_data + D4;
-        hidden_out_data = hidden_out_data + D;
-        cell_out_data = cell_out_data + D;
+        move_step();
       }
       for (int step = tstart; step < seq_len; ++step) {
         blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
@@ -323,10 +339,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         prev_hidden_data = hidden_out_data;
         prev_cell_data = cell_out_data;
 
-        // move offset
-        xx_data = xx_data + D4;
-        hidden_out_data = hidden_out_data + D;
-        cell_out_data = cell_out_data + D;
+        move_step();
       }
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
index 19f22fc7bd..5805bdf461 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
@@ -122,14 +122,15 @@ class TestFusionLSTMOpInit(TestFusionLSTMOp):
         self.has_initial_state = True
 
 
-# class TestFusionLSTMOpReverse(TestFusionLSTMOp):
-#     def set_conf(self):
-#         self.is_reverse = True
-
-# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
-#     def set_conf(self):
-#         self.has_initial_state = True
-#         self.is_reverse = True
+class TestFusionLSTMOpReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.is_reverse = True
+
+
+class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.has_initial_state = True
+        self.is_reverse = True
 
 
 class TestFusionLSTMOpMD1(TestFusionLSTMOp):