From cf8c8e72bdd1e6c76aeeee85050718710e510490 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sun, 30 Sep 2018 00:02:31 +0800
Subject: [PATCH] add vtanh and unit test

---
 paddle/fluid/operators/math/jit_kernel.h      |   4 +-
 paddle/fluid/operators/math/jit_kernel_exp.cc | 113 ++++++++++++++++++
 .../fluid/operators/math/jit_kernel_test.cc   |  66 ++++++++++
 3 files changed, 180 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 32944ae82c..eaf5fd0a87 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -28,13 +28,11 @@ namespace jitkernel {
 
 #define SIGMOID_THRESHOLD_MIN -40.0
 #define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
 
 #define AVX_FLOAT_BLOCK 8
-#define AVX_DOUBLE_BLOCK 4
 #define AVX2_FLOAT_BLOCK 8
-#define AVX2_DOUBLE_BLOCK 4
 #define AVX512_FLOAT_BLOCK 16
-#define AVX512_DOUBLE_BLOCK 8
 
 typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
 
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index 0717c2aeeb..da0a71be28 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -235,6 +235,7 @@ INTRI16_FLOAT(jit::avx512f);
 #undef INTRI16_FLOAT
 #undef INTRI_GT8LT16_FLOAT
 #undef INTRI_GT16_FLOAT
+#undef INTRI_VSIGMOID
 
 #define JITKERNEL_NEW_ACT_IMPL(ker, dtype, isa, k) \
   p = std::dynamic_pointer_cast<ker<dtype>>(       \
@@ -243,6 +244,118 @@ INTRI16_FLOAT(jit::avx512f);
 REGISTER_JITKERNEL_ARGS(vsigmoid, VSigmoidKernel, JITKERNEL_DECLARE,
                         JITKERNEL_KEY, JITKERNEL_NEW_ACT_IMPL);
 
+/* VTanh JitKernel */
+template <typename T, jit::cpu_isa_t isa, jit_block>
+class VTanhKernelImpl : public VTanhKernel<T> {
+ public:
+  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
+    vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
+    vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
+    vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
+  }
+  void Compute(const int n, const T* x, T* y) const override {
+    vscal_->Compute(n, static_cast<T>(2), x, y);
+    vsigmoid_->Compute(n, y, y);
+    vscal_->Compute(n, static_cast<T>(2), y);
+    vaddbias_->Compute(n, static_cast<T>(-1), y, y);
+  }
+
+ private:
+  std::shared_ptr<const VScalKernel<T>> vscal_;
+  std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
+  std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
+};
+
+#define INTRI_VTANH(tmp)                                   \
+  tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp);         \
+  tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
+  tmp = detail::Exp(tmp);                                  \
+  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);          \
+  tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp);          \
+  tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
+
+#define INTRI8_FLOAT(isa)                                                      \
+  template <>                                                                  \
+  void VTanhKernelImpl<float, isa, kEQ8>::Compute(const int n, const float* x, \
+                                                  float* y) const {            \
+    __m256 tmp = _mm256_loadu_ps(x);                                           \
+    INTRI_VTANH(tmp);                                                          \
+    _mm256_storeu_ps(y, tmp);                                                  \
+  }
+
+#define INTRI16_FLOAT(isa)                           \
+  template <>                                        \
+  void VTanhKernelImpl<float, isa, kEQ16>::Compute(  \
+      const int n, const float* x, float* y) const { \
+    __m256 tmp0 = _mm256_loadu_ps(x);                \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8);            \
+    INTRI_VTANH(tmp0);                               \
+    INTRI_VTANH(tmp1);                               \
+    _mm256_storeu_ps(y, tmp0);                       \
+    _mm256_storeu_ps(y + 8, tmp1);                   \
+  }
+
+#define INTRI_GT8LT16_FLOAT(isa)                       \
+  template <>                                          \
+  void VTanhKernelImpl<float, isa, kGT8LT16>::Compute( \
+      const int n, const float* x, float* y) const {   \
+    __m256 tmp = _mm256_loadu_ps(x);                   \
+    INTRI_VTANH(tmp);                                  \
+    _mm256_storeu_ps(y, tmp);                          \
+    x += AVX_FLOAT_BLOCK;                              \
+    y += AVX_FLOAT_BLOCK;                              \
+    const int rest = n - AVX_FLOAT_BLOCK;              \
+    vscal_->Compute(rest, 2.f, x, y);                  \
+    vsigmoid_->Compute(rest, y, y);                    \
+    vscal_->Compute(rest, 2.f, y);                     \
+    vaddbias_->Compute(rest, -1.f, y, y);              \
+  }
+
+#define INTRI_GT16_FLOAT(isa)                        \
+  template <>                                        \
+  void VTanhKernelImpl<float, isa, kGT16>::Compute(  \
+      const int n, const float* x, float* y) const { \
+    const int rest = n % AVX_FLOAT_BLOCK;            \
+    const int end = n - rest;                        \
+    for (int i = 0; i < end; i += AVX_FLOAT_BLOCK) { \
+      __m256 tmp = _mm256_loadu_ps(x + i);           \
+      INTRI_VTANH(tmp);                              \
+      _mm256_storeu_ps(y + i, tmp);                  \
+    }                                                \
+    x += end;                                        \
+    y += end;                                        \
+    vscal_->Compute(rest, 2.f, x, y);                \
+    vsigmoid_->Compute(rest, y, y);                  \
+    vscal_->Compute(rest, 2.f, y);                   \
+    vaddbias_->Compute(rest, -1.f, y, y);            \
+  }
+
+#ifdef __AVX__
+INTRI8_FLOAT(jit::avx);
+INTRI16_FLOAT(jit::avx);
+INTRI_GT8LT16_FLOAT(jit::avx);
+INTRI_GT16_FLOAT(jit::avx);
+#endif
+#ifdef __AVX2__
+INTRI8_FLOAT(jit::avx2);
+INTRI16_FLOAT(jit::avx2);
+// maybe use avx at gt8lt16 and gt16
+#endif
+#ifdef __AVX512F__
+INTRI8_FLOAT(jit::avx512f);
+INTRI16_FLOAT(jit::avx512f);
+// maybe use avx at gt8lt16 and gt16
+#endif
+
+#undef INTRI8_FLOAT
+#undef INTRI16_FLOAT
+#undef INTRI_GT8LT16_FLOAT
+#undef INTRI_GT16_FLOAT
+#undef INTRI_VTANH
+
+REGISTER_JITKERNEL_ARGS(vtanh, VTanhKernel, JITKERNEL_DECLARE, JITKERNEL_KEY,
+                        JITKERNEL_NEW_ACT_IMPL);
+
 #undef JITKERNEL_NEW_ACT_IMPL
 
 }  // namespace jitkernel
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 7c41787141..3aadc6ef44 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -208,6 +208,72 @@ TEST(JitKernel, vsigmoid) {
   }
 }
 
+inline float _tanh(float x) { return 2.f * _sigmoid(2.f * x) - 1.f; }
+
+void vtanh_ref(const int n, const float* x, float* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = _tanh(x[i]);
+  }
+}
+
+void vtanh_better(
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal,
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
+        vsigmoid,
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
+        vaddbias,
+    const int n, const float* x, float* y) {
+  vscal->Compute(n, 2.f, x, y);
+  vsigmoid->Compute(n, y, y);
+  vscal->Compute(n, 2.f, y);
+  vaddbias->Compute(n, -1.f, y, y);
+}
+
+TEST(JitKernel, vtanh) {
+  namespace jit = paddle::operators::math::jitkernel;
+  for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
+    std::vector<float> x(d);
+    std::vector<float> zref(d), ztgt(d);
+    RandomVec<float>(d, x.data(), -2.f, 2.f);
+    const auto& ker =
+        jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
+    const auto& vscal =
+        jit::KernelPool::Instance().template Get<jit::VScalKernel<float>>(d);
+    const auto& vsigmoid =
+        jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(d);
+    const auto& vaddbias =
+        jit::KernelPool::Instance().template Get<jit::VAddBiasKernel<float>>(d);
+    const float* x_data = x.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+    auto tmkls = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vtanh_better(vscal, vsigmoid, vaddbias, d, x_data, zref_data);
+    }
+    auto tmkle = GetCurrentUS();
+    auto trefs = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vtanh_ref(d, x_data, zref_data);
+    }
+    auto trefe = GetCurrentUS();
+    auto ttgts = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      ker->Compute(d, x_data, ztgt_data);
+    }
+    auto ttgte = GetCurrentUS();
+
+    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
+            << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
+            << " us, tgt takes: " << (ttgte - ttgts) / repeat;
+    for (int i = 0; i < d; ++i) {
+      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+    }
+  }
+}
+
 void vscal_ref(const int n, const float a, const float* x, float* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a * x[i];