From 9131a35676ce36e0aa943567c60fa39a024ef6c9 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Tue, 9 Oct 2018 22:38:26 +0800
Subject: [PATCH] replace the lstm compute with jitkernel

test=develop
---
 paddle/fluid/operators/CMakeLists.txt         |  2 +-
 paddle/fluid/operators/fusion_lstm_op.cc      | 50 ++++++---------
 paddle/fluid/operators/math/CMakeLists.txt    |  8 +--
 .../fluid/operators/math/cpu_lstm_compute.cc  | 43 -------------
 .../fluid/operators/math/cpu_lstm_compute.h   | 64 -------------------
 5 files changed, 22 insertions(+), 145 deletions(-)
 delete mode 100644 paddle/fluid/operators/math/cpu_lstm_compute.cc
 delete mode 100644 paddle/fluid/operators/math/cpu_lstm_compute.h

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 2ef13b72ed..4d8dd0df19 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -299,7 +299,7 @@ op_library(flatten_op DEPS reshape_op)
 op_library(sequence_pad_op DEPS sequence_padding)
 op_library(unstack_op DEPS stack_op)
 op_library(fake_quantize_op DEPS memory)
-op_library(fusion_lstm_op DEPS cpu_lstm_compute)
+op_library(fusion_lstm_op DEPS jit_kernel)
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
     op_library(layer_norm_op DEPS cub)
diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index ae1f6d8e48..abaa9237c0 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fusion_lstm_op.h"
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -309,11 +309,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   act_gate(D, gates + D3, gates + D3);              \
   GET_Ht(ct, gates, ht)
 
-#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
-  act_gate(D3, gates + D, gates + D);     \
-  GET_Ct(ct_1, gates, ct);                \
-  GET_Ht(ct, gates, ht)
-
 #define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht)        \
   /* get fgated and igated*/                              \
   blas.VMUL(D, wc_data, ct_1, checked_cell_data);         \
@@ -403,22 +398,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         }
       }
     } else {
-      // TODO(TJ): unly workaround, clean me
-      std::function<void(T*, const T*, T*, T*)> compute_ctht;
-      if (platform::jit::MayIUse(platform::jit::avx) &&
-          act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
-          act_cell_str == "tanh" && D == 8) {
-        compute_ctht = math::lstm_compute_ctht<T>;
-      } else {
-        compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
-          COMPUTE_CtHt(gates, ct_1, ct, ht);
-        };
-      }
+      const auto& ker =
+          math::jitkernel::KernelPool::Instance()
+              .template Get<math::jitkernel::LSTMKernel<T>, int,
+                            const std::string&, const std::string&,
+                            const std::string&>(D, act_gate_str, act_cand_str,
+                                                act_cell_str);
+
       for (int i = 0; i < N; ++i) {
         PROCESS_H0C0
         for (int step = tstart; step < seq_len; ++step) {
           GEMM_WH_ADDON(1, prev_h_data, xx_data);
-          compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data);
+          ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data);
           MOVE_ONE_STEP;
         }
       }
@@ -552,24 +543,20 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         MOVE_ONE_STEP;
       }
     } else {
-      // TODO(TJ): unly workaround, clean me
-      std::function<void(T*, const T*, T*, T*)> compute_ctht;
-      if (platform::jit::MayIUse(platform::jit::avx) &&
-          act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
-          act_cell_str == "tanh" && D == 8) {
-        compute_ctht = math::lstm_compute_ctht<T>;
-      } else {
-        compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
-          COMPUTE_CtHt(gates, ct_1, ct, ht);
-        };
-      }
+      const auto& ker =
+          math::jitkernel::KernelPool::Instance()
+              .template Get<math::jitkernel::LSTMKernel<T>, int,
+                            const std::string&, const std::string&,
+                            const std::string&>(D, act_gate_str, act_cand_str,
+                                                act_cell_str);
+
       for (int step = tstart; step < max_seq_len; ++step) {
         const int cur_bs = batch_starts[step + 1] - batch_starts[step];
         GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
         DEFINE_CUR;
         for (int i = 0; i < cur_bs; ++i) {
-          compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
-                       cur_h_out_data);
+          ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
+                           cur_h_out_data);
           MOVE_ONE_BATCH;
         }
         MOVE_ONE_STEP;
@@ -595,7 +582,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   }
 
 #undef COMPUTE_CtHt_PEEPHOLE
-#undef COMPUTE_CtHt
 #undef GET_Ct_NOH0C0
 #undef COMPUTE_CtHt_NOH0C0
 #undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 16e1dc40f1..b859636f76 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -45,8 +45,6 @@ math_library(im2col)
 if (NOT WIN32) # windows do not support avx functions yet.
 math_library(gru_compute DEPS activation_functions math_function)
 math_library(lstm_compute DEPS activation_functions)
-# TODO(TJ): ugly workaround, clean me
-cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info)
 endif (NOT WIN32)
 
 cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
@@ -76,7 +74,7 @@ if(WITH_GPU)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
-cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions)
-cc_library(jit_kernel_lstm SRCS jit_kernel_lstm.cc DEPS cpu_info cblas activation_functions)
-cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc DEPS cpu_info cblas jit_kernel_exp jit_kernel_lstm)
+cc_library(jit_kernel 
+    SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
+    DEPS cpu_info cblas activation_functions)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc
deleted file mode 100644
index e96d187933..0000000000
--- a/paddle/fluid/operators/math/cpu_lstm_compute.cc
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-#ifdef __AVX__
-template <>
-void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
-                              float* ht) {
-  namespace act = detail::forward::avx;
-  // gates: W_ch, W_ih, W_fh, W_oh
-  __m256 c, i, f, o;
-  c = _mm256_loadu_ps(gates);
-  i = _mm256_loadu_ps(gates + 8);
-  f = _mm256_loadu_ps(gates + 16);
-  o = _mm256_loadu_ps(gates + 24);
-
-  /* C_t = C_t-1 * fgated + cand_gated * igated*/
-  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
-  i = _mm256_loadu_ps(ct_1);
-  f = _mm256_mul_ps(i, act::Sigmoid(f));
-  f = _mm256_add_ps(c, f);
-  _mm256_storeu_ps(ct, f);
-
-  /* H_t = act_cell(C_t) * ogated */
-  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
-  _mm256_storeu_ps(ht, o);
-}
-#endif
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h
deleted file mode 100644
index 169a9e4b47..0000000000
--- a/paddle/fluid/operators/math/cpu_lstm_compute.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <string>
-#include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/platform/cpu_info.h"
-#ifdef __AVX__
-#include <immintrin.h>
-#endif
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-// TODO(TJ): ugly workaround, clean me
-template <typename T>
-void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
-  // gates: W_ch, W_ih, W_fh, W_oh
-  vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
-  vec_tanh<T, platform::jit::avx>(8, gates, gates);
-  const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
-  const T min = SIGMOID_THRESHOLD_MIN;
-  const T max = SIGMOID_THRESHOLD_MAX;
-  for (int d = 0; d < 8; ++d) {
-    // C_t = C_t-1 * fgated + cand_gated * igated
-    ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
-    // H_t = act_cell(C_t) * ogated
-    T tmp = ct[d] * 2;
-    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
-    vec_exp<T>(1, &tmp, &tmp);
-    tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
-    ht[d] = tmp * o[d];
-  }
-}
-
-#ifdef __AVX__
-namespace detail {
-namespace forward {
-namespace avx {
-__m256 Sigmoid(const __m256 a);
-__m256 Tanh(const __m256 a);
-
-}  // namespace avx
-}  // namespace forward
-}  // namespace detail
-
-template <>
-void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
-                              float* ht);
-
-#endif
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle