From bf9302f95015db6cadf3e814cfc4f21ef8434a3d Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Thu, 13 Dec 2018 10:18:22 +0000
Subject: [PATCH] add lstm, peephole refer and test

---
 paddle/fluid/operators/jit/gen_base.cc        |   5 -
 paddle/fluid/operators/jit/gen_base.h         |   4 -
 paddle/fluid/operators/jit/helper.cc          |  20 +++
 paddle/fluid/operators/jit/helper.h           |   4 +-
 paddle/fluid/operators/jit/kernel_base.h      |  54 ++++++-
 paddle/fluid/operators/jit/kernel_key.cc      |  38 +++++
 paddle/fluid/operators/jit/kernel_key.h       |   4 +
 .../fluid/operators/jit/refer/CMakeLists.txt  |   2 +
 paddle/fluid/operators/jit/refer/refer.cc     |   3 +
 paddle/fluid/operators/jit/refer/refer.h      |  89 ++++++++++++
 paddle/fluid/operators/jit/test.cc            | 137 ++++++++++++++++++
 paddle/fluid/operators/math/jit_kernel_impl.h |  39 -----
 .../fluid/operators/math/jit_kernel_refer.h   |  85 -----------
 13 files changed, 346 insertions(+), 138 deletions(-)
 create mode 100644 paddle/fluid/operators/jit/kernel_key.cc

diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc
index a8bf902963..310da0c76f 100644
--- a/paddle/fluid/operators/jit/gen_base.cc
+++ b/paddle/fluid/operators/jit/gen_base.cc
@@ -23,11 +23,6 @@ namespace paddle {
 namespace operators {
 namespace jit {
 
-template <>
-size_t JitCodeKey<int>(int d) {
-  return d;
-}
-
 // refer do not need useme, it would be the last one.
 void GenBase::dumpCode(const unsigned char* code) const {
   if (code) {
diff --git a/paddle/fluid/operators/jit/gen_base.h b/paddle/fluid/operators/jit/gen_base.h
index 586f4389c0..48855abd26 100644
--- a/paddle/fluid/operators/jit/gen_base.h
+++ b/paddle/fluid/operators/jit/gen_base.h
@@ -43,10 +43,6 @@ class GenBase : public Kernel {
   void dumpCode(const unsigned char* code) const;
 };
 
-// Every JitCode should have a method to get the key from attribution
-template <typename Attr>
-size_t JitCodeKey(Attr attr);
-
 // Creator is used to creat the jitcode and save in pool.
 // Every JitCode should have one creator.
 class GenCreator {
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index c010b64c9c..d6fa4891e3 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/helper.h"
+#include <algorithm>  // tolower
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -36,6 +37,8 @@ const char* to_string(KernelType kt) {
     ONE_CASE(vexp);
     ONE_CASE(vsigmoid);
     ONE_CASE(vtanh);
+    ONE_CASE(lstmctht);
+    ONE_CASE(lstmc1h1);
     default:
       PADDLE_THROW("Not support type: %d", kt);
       return "NOT JITKernel";
@@ -44,6 +47,23 @@ const char* to_string(KernelType kt) {
 }
 #undef ONE_CASE
 
+KernelType to_kerneltype(const std::string& act) {
+  std::string lower = act;
+  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
+  if (lower == "relu" || lower == "vrelu") {
+    return vrelu;
+  } else if (lower == "identity" || lower == "videntity" || lower == "") {
+    return videntity;
+  } else if (lower == "exp" || lower == "vexp") {
+    return vexp;
+  } else if (lower == "sigmoid" || lower == "vsigmoid") {
+    return vsigmoid;
+  } else if (lower == "tanh" || lower == "vtanh") {
+    return vtanh;
+  }
+  return non_kernel;
+}
+
 }  // namespace jit
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h
index 053e5ed079..302e70caa7 100644
--- a/paddle/fluid/operators/jit/helper.h
+++ b/paddle/fluid/operators/jit/helper.h
@@ -14,9 +14,7 @@
 
 #pragma once
 
-#include <memory>  // for unique_ptr
 #include <string>
-#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/operators/jit/gen_base.h"
 #include "paddle/fluid/operators/jit/kernel_base.h"
@@ -124,6 +122,8 @@ typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
 
 const char* to_string(KernelType kt);
 
+KernelType to_kerneltype(const std::string& act);
+
 }  // namespace jit
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 29b881b754..3ab0194ce2 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -20,8 +20,9 @@ namespace operators {
 namespace jit {
 
 typedef enum {
-  vmul = 0,
-  vadd = 1,
+  non_kernel = 0,
+  vmul = 1,
+  vadd = 2,
   vaddrelu,
   vsub,
   vscal,
@@ -30,7 +31,9 @@ typedef enum {
   videntity,
   vexp,
   vsigmoid,
-  vtanh
+  vtanh,
+  lstmctht,
+  lstmc1h1
 } KernelType;
 
 template <typename T>
@@ -50,6 +53,51 @@ struct XYNTuples {
   typedef void (*func_type)(const T*, T*, int);
 };
 
+typedef struct {
+  void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
+  const void* ct_1;
+  void* ct;
+  void* ht;
+  /* weight_peephole and checked data are only used in peephole*/
+  const void* wp{nullptr};  //  W_ic, W_fc, W_oc
+  void* checked{nullptr};   // size: 2 * d
+} lstm_t;
+
+typedef struct {
+  void* gates;  // gates: {x_update, x_reset; x_state}
+  const void* ht_1;
+  void* ht;
+} gru_t;
+
+struct rnn_attr_s {
+  int d;
+  KernelType act_gate, act_cand;
+  rnn_attr_s() = default;
+  rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
+      : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
+};
+
+struct lstm_attr_s : public rnn_attr_s {
+  bool use_peephole;
+  KernelType act_cell;
+  lstm_attr_s() = default;
+  lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
+              KernelType _act_cell, bool _use_peephole = false)
+      : rnn_attr_s(_d, _act_gate, _act_cand),
+        use_peephole(_use_peephole),
+        act_cell(_act_cell) {}
+};
+
+typedef struct rnn_attr_s gru_attr_t;
+typedef struct lstm_attr_s lstm_attr_t;
+
+template <typename T>
+struct LSTMTuples {
+  typedef T data_type;
+  typedef lstm_attr_t attr_type;
+  typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
+};
+
 // Just for adding to kernel pool without template
 class Kernel {
  public:
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
new file mode 100644
index 0000000000..7a9ae81f89
--- /dev/null
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -0,0 +1,38 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/jit/kernel_key.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+
+template <>
+size_t JitCodeKey<int>(const int& d) {
+  return d;
+}
+
+template <>
+size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
+  constexpr int act_type_shift = 3;  // suppot 2^3 act types
+  size_t key = attr.d;
+  int gate_key = static_cast<int>(attr.act_gate) << 1;
+  int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
+  int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
+  return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
+         attr.use_peephole;
+}
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/kernel_key.h b/paddle/fluid/operators/jit/kernel_key.h
index af9df77337..611a0210d6 100644
--- a/paddle/fluid/operators/jit/kernel_key.h
+++ b/paddle/fluid/operators/jit/kernel_key.h
@@ -44,6 +44,10 @@ struct KernelKey {
   bool operator!=(const KernelKey& o) const { return !(*this == o); }
 };
 
+// Every JitCode should have a method to get the key from attribution
+template <typename Attr>
+size_t JitCodeKey(const Attr& attr);
+
 }  // namespace jit
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index dc07ddb914..e30923c4fd 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -18,3 +18,5 @@ USE_JITKERNEL_REFER(videntity)
 USE_JITKERNEL_REFER(vexp)
 USE_JITKERNEL_REFER(vsigmoid)
 USE_JITKERNEL_REFER(vtanh)
+USE_JITKERNEL_REFER(lstmctht)
+USE_JITKERNEL_REFER(lstmc1h1)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index f716ca89c5..59b3ce5248 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -35,4 +35,7 @@ REGISTER_REFER_KERNEL(vexp, VExp);
 REGISTER_REFER_KERNEL(vsigmoid, VSigmoid);
 REGISTER_REFER_KERNEL(vtanh, VTanh);
 
+REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt);
+REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1);
+
 #undef REGISTER_REFER_KERNEL
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index 7ef60a2d53..a93123df9d 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -110,6 +110,91 @@ void VTanh(const T* x, T* y, int n) {
   }
 }
 
+template <typename T>
+void (*getActFunc(KernelType type))(const T*, T*, int) {  // NOLINT
+  if (type == vsigmoid) {
+    return VSigmoid<T>;
+  } else if (type == vrelu) {
+    return VRelu<T>;
+  } else if (type == vtanh) {
+    return VTanh<T>;
+  } else if (type == videntity) {
+    return VIdentity<T>;
+  }
+  PADDLE_THROW("Not support type: %s", type);
+  return nullptr;
+}
+
+// compute ct and ht
+template <typename T>
+void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
+  T* ct = reinterpret_cast<T*>(step->ct);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* wp = reinterpret_cast<const T*>(step->wp);
+  T* checked = reinterpret_cast<T*>(step->checked);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  auto act_cand = getActFunc<T>(attr->act_cand);
+  auto act_cell = getActFunc<T>(attr->act_cell);
+  int d = attr->d;
+  int d2 = d * 2;
+  int d3 = d * 3;
+  // gates: W_ch, W_ih, W_fh, W_oh
+  if (attr->use_peephole) {
+    VMul(wp, ct_1, checked, d);
+    VMul(wp + d, ct_1, checked + d, d);
+    VAdd(checked, gates + d, gates + d, d2);
+    act_gate(gates + d, gates + d, d2);
+  } else {
+    act_gate(gates + d, gates + d, d3);
+  }
+
+  // C_t = C_t-1 * fgated + cand_gated * igated
+  act_cand(gates, gates, d);
+  VMul(gates, gates + d, gates + d, d);
+  VMul(ct_1, gates + d2, gates + d2, d);
+  VAdd(gates + d, gates + d2, ct, d);
+
+  if (attr->use_peephole) {
+    // get ogated
+    VMul(wp + d2, ct, gates + d, d);
+    VAdd(gates + d, gates + d3, gates + d3, d);
+    act_gate(gates + d3, gates + d3, d);
+  }
+  // H_t = act_cell(C_t) * ogated
+  act_cell(ct, gates + d2, d);
+  VMul(gates + d2, gates + d3, ht, d);
+}
+
+// compute c1 and h1 without c0 or h0
+template <typename T>
+void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ct = reinterpret_cast<T*>(step->ct);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  auto act_cand = getActFunc<T>(attr->act_cand);
+  auto act_cell = getActFunc<T>(attr->act_cell);
+  int d = attr->d;
+  int d2 = d * 2;
+  int d3 = d * 3;
+  /* C_t = igated * cgated*/
+  act_gate(gates + d, gates + d, d);
+  act_cand(gates, gates, d);
+  VMul(gates, gates + d, ct, d);
+  if (attr->use_peephole) {
+    // get outgated, put W_oc * C_t on igated
+    const T* wp = reinterpret_cast<const T*>(step->wp);
+    VMul(wp + d2, ct, gates + d, d);
+    VAdd(gates + d, gates + d3, gates + d3, d);
+  }
+  /* H_t = act_cell(C_t) * ogated */
+  act_gate(gates + d3, gates + d3, d);
+  act_cell(ct, gates + d2, d);
+  VMul(gates + d2, gates + d3, ht, d);
+}
+
 #define DECLARE_REFER_KERNEL(name, tuples)             \
   template <typename T>                                \
   class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -134,6 +219,10 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
 DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
 DECLARE_REFER_KERNEL(VTanh, XYNTuples);
 
+// lstm_t* , const lstm_attr_t*
+DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
+DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index 4c9b853b6e..03e56416b2 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -350,6 +350,143 @@ TEST(JITKernel, vtanh) {
   TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
 }
 
+template <typename T, typename KernelTuples>
+void TestLSTMFunc(const typename KernelTuples::func_type tgt,
+                  const std::vector<T>& xsrc, const std::vector<T>& wp,
+                  const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
+                  const std::vector<T>& ht_ref,
+                  const paddle::operators::jit::lstm_attr_t& attr) {
+  EXPECT_TRUE(tgt != nullptr);
+  EXPECT_EQ(ct_ref.size(), ht_ref.size());
+  EXPECT_EQ(ct_1.size(), ht_ref.size());
+  EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
+  EXPECT_EQ(wp.size(), 3 * ht_ref.size());
+
+  // x could be changed after compute, so copy to save src
+  int d = ht_ref.size();
+  std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
+  std::vector<T> checked(2 * d);
+  std::copy(xsrc.begin(), xsrc.end(), x.begin());
+
+  const T* ct_1_data = ct_1.data();
+  const T* wp_data = wp.data();
+  const T* ct_ref_data = ct_ref.data();
+  const T* ht_ref_data = ht_ref.data();
+  T* x_data = x.data();
+  T* ct_data = ct.data();
+  T* ht_data = ht.data();
+  T* checked_data = checked.data();
+
+  paddle::operators::jit::lstm_t step;
+  step.gates = x_data;
+  step.ct_1 = ct_1_data;
+  step.ct = ct_data;
+  step.ht = ht_data;
+  if (attr.use_peephole) {
+    step.wp = wp_data;
+    step.checked = checked_data;
+  }
+
+  tgt(&step, &attr);
+  ExpectEQ<T>(ct_data, ct_ref_data, d);
+  ExpectEQ<T>(ht_data, ht_ref_data, d);
+}
+
+template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
+void TestLSTMKernel() {
+  namespace jit = paddle::operators::jit;
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
+  for (int d : TestSizes()) {
+    for (bool use_peephole : {true, false}) {
+      for (auto& act_gate : all_acts) {
+        for (auto& act_cand : all_acts) {
+          for (auto& act_cell : all_acts) {
+            std::string info = act_gate + act_cand + act_cell +
+                               (use_peephole ? "peephole_" : "") + "size_" +
+                               std::to_string(d);
+            const jit::lstm_attr_t attr(
+                d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
+                jit::to_kerneltype(act_cell), use_peephole);
+            auto ref = jit::GetRefer<KT, jit::LSTMTuples<T>>();
+            EXPECT_TRUE(ref != nullptr);
+            std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
+            std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
+            RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
+            RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
+            RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
+            // x could be changed after compute, so copy to save src
+            std::vector<T> x(xsrc.size());
+            std::copy(xsrc.begin(), xsrc.end(), x.begin());
+            const T* ct_1_data = ct_1.data();
+            const T* wp_data = wp.data();
+            T* x_data = x.data();
+            T* checked_data = checked.data();
+            T* ct_ref_data = ct_ref.data();
+            T* ht_ref_data = ht_ref.data();
+            jit::lstm_t step;
+            step.gates = x_data;
+            step.ct_1 = ct_1_data;
+            step.ct = ct_ref_data;
+            step.ht = ht_ref_data;
+            if (use_peephole) {
+              step.wp = wp_data;
+              step.checked = checked_data;
+            }
+            ref(&step, &attr);
+
+            // test jitcode
+            auto jitcode =
+                jit::GetJitCode<KT, jit::LSTMTuples<T>, PlaceType>(attr);
+            if (jitcode) {
+              VLOG(10) << "Test Jitcode Kernel " << info;
+              TestLSTMFunc<T, jit::LSTMTuples<T>>(jitcode, xsrc, wp, ct_1,
+                                                  ct_ref, ht_ref, attr);
+            }
+
+            // test all impls in more
+            jit::KernelKey kkey(KT, PlaceType());
+            auto& pool = jit::KernelPool().Instance().AllKernels();
+            auto iter = pool.find(kkey);
+            if (iter != pool.end()) {
+              auto& impls = iter->second;
+              for (auto& impl : impls) {
+                auto i =
+                    dynamic_cast<const jit::KernelImpl<jit::LSTMTuples<T>>*>(
+                        impl.get());
+                if (i && i->UseMe(attr)) {
+                  auto more = i->GetFunc();
+                  VLOG(10) << "Test More Kernel " << info;
+                  TestLSTMFunc<T, jit::LSTMTuples<T>>(more, xsrc, wp, ct_1,
+                                                      ct_ref, ht_ref, attr);
+                }
+              }
+            }
+            // Test result from Get function
+            auto tgt = jit::Get<KT, jit::LSTMTuples<T>, PlaceType>(attr);
+            TestLSTMFunc<T, jit::LSTMTuples<T>>(tgt, xsrc, wp, ct_1, ct_ref,
+                                                ht_ref, attr);
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(JITKernel, lstmctht) {
+  namespace jit = paddle::operators::jit;
+  TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
+  TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
+}
+
+TEST(JITKernel, lstmc1h1) {
+  namespace jit = paddle::operators::jit;
+  TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
+  TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
+}
+
+// TODO(TJ): refine the tests template
+
 TEST(JITKernel, pool) {
   // TODO(TJ): add some test
 }
diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h
index ba5f20e533..025343dfad 100644
--- a/paddle/fluid/operators/math/jit_kernel_impl.h
+++ b/paddle/fluid/operators/math/jit_kernel_impl.h
@@ -28,45 +28,6 @@ namespace jitkernel {
 #define YMM_FLOAT_BLOCK 8
 #define ZMM_FLOAT_BLOCK 16
 
-typedef struct {
-  void* gates;  // gates: W_ch, W_ih, W_fh, W_oh
-  const void* ct_1;
-  void* ct;
-  void* ht;
-  /* weight_peephole and checked data are only used in peephole*/
-  const void* wp{nullptr};
-  void* checked{nullptr};
-} lstm_t;
-
-typedef struct {
-  void* gates;  // gates: {W_update, W_reset; W_state}
-  const void* ht_1;
-  void* ht;
-} gru_t;
-
-struct rnn_attr_s {
-  int d;
-  std::string act_gate, act_cand;
-  rnn_attr_s() = default;
-  rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
-      : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
-};
-
-struct lstm_attr_s : public rnn_attr_s {
-  bool use_peephole;
-  std::string act_cell;
-  lstm_attr_s() = default;
-  lstm_attr_s(int _d, const std::string& _act_gate,
-              const std::string& _act_cand, const std::string& _act_cell,
-              bool _use_peephole = false)
-      : rnn_attr_s(_d, _act_gate, _act_cand),
-        use_peephole(_use_peephole),
-        act_cell(_act_cell) {}
-};
-
-typedef struct rnn_attr_s gru_attr_t;
-typedef struct lstm_attr_s lstm_attr_t;
-
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h
index a03e851de5..122cbcb0d6 100644
--- a/paddle/fluid/operators/math/jit_kernel_refer.h
+++ b/paddle/fluid/operators/math/jit_kernel_refer.h
@@ -24,91 +24,6 @@ namespace math {
 namespace jitkernel {
 namespace refer {
 
-template <typename T>
-void (*getActFunc(const std::string& type))(const T*, T*, int) {  // NOLINT
-  if (type == "sigmoid") {
-    return VSigmoid<T>;
-  } else if (type == "relu") {
-    return VRelu<T>;
-  } else if (type == "tanh") {
-    return VTanh<T>;
-  } else if (type == "identity" || type == "") {
-    return VIdentity<T>;
-  }
-  PADDLE_THROW("Not support type: %s", type);
-  return nullptr;
-}
-
-// compute ct and ht
-template <typename T>
-void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
-  T* gates = reinterpret_cast<T*>(step->gates);
-  const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
-  T* ct = reinterpret_cast<T*>(step->ct);
-  T* ht = reinterpret_cast<T*>(step->ht);
-  const T* wp = reinterpret_cast<const T*>(step->wp);
-  T* checked = reinterpret_cast<T*>(step->checked);
-  auto act_gate = getActFunc<T>(attr->act_gate);
-  auto act_cand = getActFunc<T>(attr->act_cand);
-  auto act_cell = getActFunc<T>(attr->act_cell);
-  int d = attr->d;
-  int d2 = d * 2;
-  int d3 = d * 3;
-  // gates: W_ch, W_ih, W_fh, W_oh
-  if (attr->use_peephole) {
-    VMul(wp, ct_1, checked, d);
-    VMul(wp + d, ct_1, checked + d, d);
-    VAdd(checked, gates + d, gates + d, d2);
-    act_gate(gates + d, gates + d, d2);
-  } else {
-    act_gate(gates + d, gates + d, d3);
-  }
-
-  // C_t = C_t-1 * fgated + cand_gated * igated
-  act_cand(gates, gates, d);
-  VMul(gates, gates + d, gates + d, d);
-  VMul(ct_1, gates + d2, gates + d2, d);
-  VAdd(gates + d, gates + d2, ct, d);
-
-  if (attr->use_peephole) {
-    // get ogated
-    VMul(wp + d2, ct, gates + d, d);
-    VAdd(gates + d, gates + d3, gates + d3, d);
-    act_gate(gates + d3, gates + d3, d);
-  }
-  // H_t = act_cell(C_t) * ogated
-  act_cell(ct, gates + d2, d);
-  VMul(gates + d2, gates + d3, ht, d);
-}
-
-// compute c1 and h1 without c0 or h0
-template <typename T>
-void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
-  T* gates = reinterpret_cast<T*>(step->gates);
-  T* ct = reinterpret_cast<T*>(step->ct);
-  T* ht = reinterpret_cast<T*>(step->ht);
-  auto act_gate = getActFunc<T>(attr->act_gate);
-  auto act_cand = getActFunc<T>(attr->act_cand);
-  auto act_cell = getActFunc<T>(attr->act_cell);
-  int d = attr->d;
-  int d2 = d * 2;
-  int d3 = d * 3;
-  /* C_t = igated * cgated*/
-  act_gate(gates + d, gates + d, d);
-  act_cand(gates, gates, d);
-  VMul(gates, gates + d, ct, d);
-  if (attr->use_peephole) {
-    // get outgated, put W_oc * C_t on igated
-    const T* wp = reinterpret_cast<const T*>(step->wp);
-    VMul(wp + d2, ct, gates + d, d);
-    VAdd(gates + d, gates + d3, gates + d3, d);
-  }
-  /* H_t = act_cell(C_t) * ogated */
-  act_gate(gates + d3, gates + d3, d);
-  act_cell(ct, gates + d2, d);
-  VMul(gates + d2, gates + d3, ht, d);
-}
-
 // compute h1 without h0
 template <typename T>
 void GRUH1(gru_t* step, const gru_attr_t* attr) {