From 3e01a4048f28ad5cf4b33fb808b07965d9e7ff5d Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Fri, 28 Dec 2018 16:34:13 +0000
Subject: [PATCH 01/28] add refer seqpool jitkernel

---
 paddle/fluid/operators/jit/kernel_base.h      | 20 +++++++++++++++++++
 paddle/fluid/operators/jit/kernel_key.cc      |  6 ++++++
 .../fluid/operators/jit/refer/CMakeLists.txt  |  1 +
 paddle/fluid/operators/jit/refer/refer.cc     |  2 ++
 paddle/fluid/operators/jit/refer/refer.h      | 16 +++++++++++++++
 5 files changed, 45 insertions(+)

diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index b4a2d5d473..8f13fbb16e 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -41,6 +41,7 @@ typedef enum {
   kCRFDecoding,
   kLayerNorm,
   kNCHW16CMulNC,
+  kSeqPool,
 } KernelType;
 
 template <typename T>
@@ -112,6 +113,25 @@ struct GRUTuples {
   typedef void (*func_type)(gru_t*, const gru_attr_t*);
 };
 
+typedef enum {
+  non = 0,
+  sum,
+  avg,
+  sqrt,
+} SeqPoolType;
+
+typedef struct {
+  int h, w;
+  SeqPoolType type;
+} seq_pool_attr_t;
+
+template <typename T>
+struct SeqPoolTuples {
+  typedef T data_type;
+  typedef seq_pool_attr_t attr_type;
+  typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
+};
+
 template <typename T>
 struct CRFDecodingTuples {
   typedef T data_type;
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 4e6a19f04f..6b0025a75a 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -42,6 +42,12 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
          (static_cast<int>(attr.act_cand) << act_type_shift);
 }
 
+template <>
+size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
+  size_t key = static_cast<size_t>(attr.type);
+  return key + (attr.w << act_type_shift);
+}
+
 }  // namespace jit
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index 07497b7320..0f626bb3bf 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2)
 USE_JITKERNEL_REFER(kCRFDecoding)
 USE_JITKERNEL_REFER(kLayerNorm)
 USE_JITKERNEL_REFER(kNCHW16CMulNC)
+USE_JITKERNEL_REFER(kSeqPool)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index d196266326..85381daa47 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
 
 REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
 
+REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
+
 #undef REGISTER_REFER_KERNEL
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index 0fd1b89dfd..52fe2de02a 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
   }
 }
 
+template <typename T>
+void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
+  PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet");
+  for (int w = 0; w < attr->w; ++w) {
+    const T* src = x + w;
+    T* dst = y + w;
+    *dst = static_cast<T>(0);
+    for (int h = 0; h < attr->h; ++h) {
+      *dst = *dst + *src;
+      src += attr->w;
+    }
+  }
+}
+
 #define DECLARE_REFER_KERNEL(name, tuples)             \
   template <typename T>                                \
   class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
 
 DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
 
+DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer

From e58a569c6cdb8ab66c7dff69395518cee224fe67 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Fri, 28 Dec 2018 16:35:00 +0000
Subject: [PATCH 02/28] use seqpool jitkernel

---
 paddle/fluid/operators/math/CMakeLists.txt    |  2 +-
 .../fluid/operators/math/sequence_pooling.cc  | 32 ++++++++++++-------
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index ea6aebd291..600ab14d37 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -51,7 +51,7 @@ math_library(pooling)
 math_library(selected_rows_functor DEPS selected_rows math_function blas)
 math_library(sequence2batch)
 math_library(sequence_padding)
-math_library(sequence_pooling DEPS math_function)
+math_library(sequence_pooling DEPS math_function jit_kernel_helper)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function)
 
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 6d491dbf1e..23dc516933 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <string>
 
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence_pooling.h"
@@ -239,15 +240,33 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       last_pool(context, input, output);
       return;
     }
-
     if (pooltype == "FIRST") {
       math::FirstSeqPoolFunctor<T> first_pool;
       first_pool(context, input, output);
       return;
     }
+
     auto lod = input.lod()[0];
+    if (pooltype == "SUM") {
+      auto place = context.GetPlace();
+      PADDLE_ENFORCE(platform::is_cpu_place(place));
+      const T* src = input.data<T>();
+      T* dst = output->mutable_data<T>(place);
+      jit::seq_pool_attr_t attr;
+      attr.w = input.numel() / input.dims()[0];
+      attr.type = jit::SeqPoolType::sum;
+      auto seqpool =
+          jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
+              attr);
+      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+        attr.h = static_cast<int>(lod[i + 1] - lod[i]);
+        seqpool(src, dst, &attr);
+        dst += attr.w;
+        src += attr.h * attr.w;
+      }
+      return;
+    }
     auto& place = *context.eigen_device();
-    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       Tensor in_t =
           input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
@@ -258,15 +277,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       auto out_e = EigenVector<T>::Flatten(out_t);
       if (pooltype == "AVERAGE") {
         out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SUM") {
-        if (h > 0) {
-          const T* in_data = in_t.data<T>();
-          T* out_data = out_t.mutable_data<T>(context.GetPlace());
-          blas.VCOPY(w, in_data, out_data);
-          for (int64_t r = 1; r != h; ++r) {
-            blas.AXPY(w, 1., in_data + r * w, out_data);
-          }
-        }
       } else if (pooltype == "SQRT") {
         out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                               std::sqrt(static_cast<T>(h));

From 142bb417483f9e0e71a26d24d30eb01c6d2f7754 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sat, 29 Dec 2018 05:13:08 +0000
Subject: [PATCH 03/28] add seqpool jitkernel test and benchmark

---
 paddle/fluid/operators/jit/benchmark.cc       | 21 ++++++++
 paddle/fluid/operators/jit/helper.cc          | 15 ++++++
 paddle/fluid/operators/jit/helper.h           |  6 +++
 paddle/fluid/operators/jit/kernel_base.h      | 19 ++++----
 paddle/fluid/operators/jit/refer/refer.h      |  2 +-
 paddle/fluid/operators/jit/test.cc            | 48 +++++++++++++++++++
 .../fluid/operators/math/sequence_pooling.cc  |  2 +-
 7 files changed, 103 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 437005825d..f64e43389a 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -190,6 +190,24 @@ void BenchGRUKernel() {
   }
 }
 
+template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
+void BenchSeqPoolKernel() {  // BenchAllImpls per (h, w) from TestSizes()
+  std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
+  for (auto type : pool_types) {
+    for (int h : TestSizes()) {
+      for (int w : TestSizes()) {
+        const jit::seq_pool_attr_t attr(h, w, type);
+        std::vector<T> x(h * w), y(w);  // x: h rows of w values; y: one row
+        RandomVec<T>(h * w, x.data(), -2.f, 2.f);
+        const T* x_data = x.data();
+        T* y_data = y.data();
+        BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
+                                                            y_data, &attr);
+      }
+    }
+  }
+}
+
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
@@ -228,4 +246,7 @@ int main(int argc, char* argv[]) {
   BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
   BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
   BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
+
+  // seq pool function
+  BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
 }
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index d00584baa0..7d02590f2e 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -26,6 +26,7 @@ namespace jit {
 
 const char* to_string(KernelType kt) {
   switch (kt) {
+    ONE_CASE(kNone);
     ONE_CASE(kVMul);
     ONE_CASE(kVAdd);
     ONE_CASE(kVAddRelu);
@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kCRFDecoding);
     ONE_CASE(kLayerNorm);
     ONE_CASE(kNCHW16CMulNC);
+    ONE_CASE(kSeqPool);
     default:
       PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
       return "NOT JITKernel";
   }
   return nullptr;
 }
+
+const char* to_string(SeqPoolType tp) {  // SeqPoolType -> text via ONE_CASE
+  switch (tp) {
+    ONE_CASE(kNonePoolType);
+    ONE_CASE(kSum);
+    ONE_CASE(kAvg);
+    ONE_CASE(kSqrt);
+    default:
+      PADDLE_THROW("Not support type: %d, or forget to add it.", tp);
+      return "NOT PoolType";  // fallback text placed after the throw
+  }
+  return nullptr;
+}
 #undef ONE_CASE
 
 KernelType to_kerneltype(const std::string& act) {
diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h
index 412df86aa1..fbf34fc4b3 100644
--- a/paddle/fluid/operators/jit/helper.h
+++ b/paddle/fluid/operators/jit/helper.h
@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get(
 }
 
 const char* to_string(KernelType kt);
+const char* to_string(SeqPoolType kt);
 
 KernelType to_kerneltype(const std::string& act);
 
@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
      << "],act_cand[" << to_string(attr.act_cand) << "]";
   return os;
 }
+inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
+  os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type["
+     << to_string(attr.type) << "]";  // pool type goes through to_string
+  return os;  // returned so that further << calls can be chained
+}
 
 }  // namespace jit
 }  // namespace operators
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 8f13fbb16e..2659374650 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -44,6 +44,13 @@ typedef enum {
   kSeqPool,
 } KernelType;
 
+typedef enum {
+  kNonePoolType = 0,  // explicit 0; presumably means "unset" - TODO confirm
+  kSum,               // per-column sum over the h rows
+  kAvg,               // sum scaled by 1 / h
+  kSqrt,              // sum scaled by 1 / sqrt(h)
+} SeqPoolType;
+
 template <typename T>
 struct XYZNTuples {
   typedef T data_type;
@@ -113,16 +120,12 @@ struct GRUTuples {
   typedef void (*func_type)(gru_t*, const gru_attr_t*);
 };
 
-typedef enum {
-  non = 0,
-  sum,
-  avg,
-  sqrt,
-} SeqPoolType;
-
-typedef struct {
+typedef struct seq_pool_attr_s {
   int h, w;
   SeqPoolType type;
+  seq_pool_attr_s() = default;
+  explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type)
+      : h(height), w(width), type(pool_type) {}
 } seq_pool_attr_t;
 
 template <typename T>
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index 52fe2de02a..c2aa922528 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -334,7 +334,7 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
 
 template <typename T>
 void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
-  PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet");
+  PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet");
   for (int w = 0; w < attr->w; ++w) {
     const T* src = x + w;
     T* dst = y + w;
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index a73e2a60ae..0f1776507a 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
   }
 };
 
+template <typename T>  // runs one SeqPool impl and compares it with yref
+struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
+                         std::vector<T>> {
+  void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
+                  const std::vector<T>& x, const std::vector<T>& yref,
+                  const typename jit::SeqPoolTuples<T>::attr_type& attr) {
+    EXPECT_TRUE(tgt != nullptr);
+    EXPECT_EQ(x.size() % yref.size(), static_cast<size_t>(0));  // no sign mix
+    const int w = static_cast<int>(yref.size());  // one output per column
+    std::vector<T> y(w);
+    const T* x_data = x.data();
+    const T* yref_data = yref.data();
+    T* y_data = y.data();
+    tgt(x_data, y_data, &attr);  // kernel under test writes y
+    ExpectEQ<T>(y_data, yref_data, w);
+  }
+};
+
 template <paddle::operators::jit::KernelType KT, typename KernelTuples,
           typename PlaceType, typename... Args>
 void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
@@ -415,6 +433,30 @@ void TestGRUKernel() {
   }
 }
 
+template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
+void TestSeqPoolKernel() {  // checks KT impls against refer via TestAllImpls
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  // TODO(TJ): support more pool types; only kSum is exercised for now
+  std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
+  for (auto type : pool_types) {
+    for (int h : TestSizes()) {
+      for (int w : TestSizes()) {
+        const jit::seq_pool_attr_t attr(h, w, type);
+        auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
+        EXPECT_TRUE(ref != nullptr);
+        std::vector<T> x(h * w), yref(w);  // x: h rows of w values
+        RandomVec<T>(h * w, x.data(), -2.f, 2.f);
+        const T* x_data = x.data();
+        T* yref_data = yref.data();
+        ref(x_data, yref_data, &attr);  // refer output is the expected value
+        VLOG(10) << attr;
+        TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
+                     std::vector<T>>(attr, x, yref, attr);
+      }
+    }
+  }
+}
+
 template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void TestNCHW16CMulNCKernel() {
   VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
@@ -569,6 +611,12 @@ TEST(JITKernel, kGRUHtPart2) {
   TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
 }
 
+TEST(JITKernel, kSeqPool) {  // run the SeqPool check for float and double
+  namespace jit = paddle::operators::jit;
+  TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
+  TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
+}
+
 TEST(JITKernel, kNCHW16CMulNC) {
   namespace jit = paddle::operators::jit;
   TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 23dc516933..98707c936d 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -254,7 +254,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       T* dst = output->mutable_data<T>(place);
       jit::seq_pool_attr_t attr;
       attr.w = input.numel() / input.dims()[0];
-      attr.type = jit::SeqPoolType::sum;
+      attr.type = jit::SeqPoolType::kSum;
       auto seqpool =
           jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
               attr);

From c50060bb264a3e70ef55abfdd8ab74416cb14121 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sat, 29 Dec 2018 06:26:02 +0000
Subject: [PATCH 04/28] add jitcode impl and use it

---
 paddle/fluid/operators/jit/gen/CMakeLists.txt |   1 +
 paddle/fluid/operators/jit/gen/seqpool.cc     | 132 ++++++++++++++++++
 paddle/fluid/operators/jit/gen/seqpool.h      |  98 +++++++++++++
 paddle/fluid/operators/jit/kernel_key.cc      |   7 +-
 .../fluid/operators/math/sequence_pooling.cc  |   6 +-
 5 files changed, 239 insertions(+), 5 deletions(-)
 create mode 100644 paddle/fluid/operators/jit/gen/seqpool.cc
 create mode 100644 paddle/fluid/operators/jit/gen/seqpool.h

diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt
index 8a54010830..2b8c758a03 100644
--- a/paddle/fluid/operators/jit/gen/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1)
 USE_JITKERNEL_GEN(kGRUHtPart1)
 USE_JITKERNEL_GEN(kGRUHtPart2)
 USE_JITKERNEL_GEN(kNCHW16CMulNC)
+USE_JITKERNEL_GEN(kSeqPool)
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
new file mode 100644
index 0000000000..ce6801b030
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -0,0 +1,132 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/fluid/operators/jit/gen/seqpool.h"
+#include "paddle/fluid/operators/jit/registry.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+void SeqPoolJitCode::genCode() {  // emits code specialised for h_ and w_
+  constexpr int block = YMM_FLOAT_BLOCK;
+  constexpr int max_num_regs = 8;  // ymm accumulators per group
+  const int num_block = w_ / block;  // whole blocks per row
+  const int num_groups = num_block / max_num_regs;
+  int rest_num_regs = num_block % max_num_regs;  // blocks in last group
+  if (type_ == SeqPoolType::kAvg) {  // ctor currently rejects non-sum types
+    float scalar = 1.f / h_;
+    mov(reg32_scalar, scalar);  // NOTE(review): float operand - verify
+  } else if (type_ == SeqPoolType::kSqrt) {
+    float scalar = 1.f / std::sqrt(static_cast<float>(h_));
+    mov(reg32_scalar, scalar);  // NOTE(review): float operand - verify
+  }
+
+  // TODO(TJ): make height load from params
+  const int group_len = max_num_regs * block * sizeof(float);  // in bytes
+  for (int g = 0; g < num_groups; ++g) {
+    pool_height<ymm_t>(g * group_len, block, max_num_regs);
+  }
+  if (rest_num_regs > 0) {
+    pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
+  }
+
+  // rest part: the w_ % block tail columns, in chunks of 4, 2 and 1 floats
+  const int rest = w_ % block;
+  const bool has_block4 = rest / 4 > 0;
+  const bool has_block2 = (rest % 4) / 2 > 0;
+  const bool has_block1 = (rest % 2) == 1;
+  const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float);
+  for (int h = 0; h < h_; ++h) {
+    int offset = h * w_ * sizeof(float) + w_offset;
+    const int shift_regs = (h == 0) ? 0 : max_num_regs;
+    int reg_idx = 0;
+    if (has_block4) {
+      vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
+      offset += sizeof(float) * 4;
+      reg_idx++;
+    }
+    if (has_block2) {
+      vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
+      offset += sizeof(float) * 2;
+      reg_idx++;
+    }
+    if (has_block1) {
+      vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
+      reg_idx++;
+    }
+    rest_num_regs = reg_idx;  // now the number of tail registers
+    if (h > 0) {  // fold this row into the accumulators
+      for (int i = 0; i < reg_idx; ++i) {
+        vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
+      }
+    }
+  }
+  // save right now: optional scaling, then store the tail through param2
+  int offset = w_offset;
+  if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
+    vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
+    for (int i = 0; i < rest_num_regs; ++i) {
+      vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
+    }
+  }
+  int reg_idx = 0;
+  if (has_block4) {
+    vmovups(ptr[param2 + offset], xmm_t(reg_idx));
+    offset += sizeof(float) * 4;
+    reg_idx++;
+  }
+  if (has_block2) {
+    vmovq(ptr[param2 + offset], xmm_t(reg_idx));
+    offset += sizeof(float) * 2;
+    reg_idx++;
+  }
+  if (has_block1) {
+    vmovss(ptr[param2 + offset], xmm_t(reg_idx));
+  }
+  ret();
+}
+
+class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
+ public:  // creator for the generated kSeqPool code
+  bool UseMe(const seq_pool_attr_t& attr) const override {  // AVX + sum only
+    return platform::MayIUse(platform::avx) && attr.type == SeqPoolType::kSum;
+  }
+  size_t CodeSize(const seq_pool_attr_t& attr) const override {
+    // TODO(TJ): remove attr.h when enabled height
+    bool yes =
+        attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt;
+    return 96 /* basic */ +
+           ((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */
+            * (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) *
+               8;
+  }
+  std::unique_ptr<GenBase> CreateJitCode(
+      const seq_pool_attr_t& attr) const override {
+    PADDLE_ENFORCE_GT(attr.w, 0);  // code is unrolled over w and h,
+    PADDLE_ENFORCE_GT(attr.h, 0);  // so both must be positive
+    return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
+  }
+};
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
+
+namespace gen = paddle::operators::jit::gen;
+
+REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator);
diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h
new file mode 100644
index 0000000000..eb2d191382
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/seqpool.h
@@ -0,0 +1,98 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include "glog/logging.h"
+#include "paddle/fluid/operators/jit/gen/jitcode.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+class SeqPoolJitCode : public JitCode {  // code for one fixed (h, w) pool
+ public:
+  explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
+                          size_t code_size = 256 * 1024,
+                          void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) {
+    if (type_ != SeqPoolType::kSum) {  // avg / sqrt are not enabled yet
+      LOG(FATAL) << "Only support sum pool yet ";
+    }
+    this->genCode();
+  }
+
+  virtual const char* name() const {
+    thread_local std::string base;  // keeps c_str() valid until the next call
+    base = "SeqPoolJitCode";
+    if (type_ == SeqPoolType::kSum) {
+      base += "_Sum";
+    } else if (type_ == SeqPoolType::kAvg) {
+      base += "_Avg";
+    } else if (type_ == SeqPoolType::kSqrt) {
+      base += "_Sqrt";
+    }
+    // TODO(TJ): make h load from params
+    base += ("_W" + std::to_string(w_) + "_H" + std::to_string(h_));
+    return base.c_str();
+  }
+  void genCode() override;
+
+ protected:
+  template <typename JMM>  // column sums of max_num_regs blocks -> param2
+  void pool_height(int w_offset, int block, int max_num_regs) {
+    for (int h = 0; h < h_; ++h) {
+      int offset = h * w_ * sizeof(float) + w_offset;
+      const int shift_regs = (h == 0) ? 0 : max_num_regs;
+      for (int i = 0; i < max_num_regs; ++i) {
+        vmovups(JMM(i + shift_regs), ptr[param1 + offset]);
+        offset += sizeof(float) * block;
+      }
+      if (h > 0) {
+        // sum anyway
+        for (int i = 0; i < max_num_regs; ++i) {
+          vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
+        }
+      }
+    }
+    // save right now
+    if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
+      vbroadcastss(JMM(max_num_regs), reg32_scalar);
+    }
+    int offset = w_offset;
+    for (int i = 0; i < max_num_regs; ++i) {
+      if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
+        vmulps(JMM(i), JMM(i), JMM(max_num_regs));
+      }
+      vmovups(ptr[param2 + offset], JMM(i));
+      offset += sizeof(float) * block;
+    }
+  }
+
+ private:
+  int h_;  // rows to reduce, baked into the generated code
+  int w_;  // values per row
+  SeqPoolType type_;
+  reg64_t param1{abi_param1};  // loads read through this (input)
+  reg64_t param2{abi_param2};  // stores write through this (output)
+  reg64_t param3{abi_param3};  // third argument; not used by the code here
+  reg32_t reg32_scalar{r8d};
+};
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 6b0025a75a..db78ed8ad8 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -44,8 +44,11 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
 
 template <>
 size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
-  size_t key = static_cast<size_t>(attr.type);
-  return key + (attr.w << act_type_shift);
+  size_t key = attr.w;
+  // TODO(TJ): support height, then removed it from key
+  constexpr int w_shift = 30;
+  return (key << act_type_shift) + static_cast<int>(attr.type) +
+         (static_cast<size_t>(attr.h) << (act_type_shift + w_shift));
 }
 
 }  // namespace jit
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 98707c936d..283e2e251a 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -255,11 +255,11 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       jit::seq_pool_attr_t attr;
       attr.w = input.numel() / input.dims()[0];
       attr.type = jit::SeqPoolType::kSum;
-      auto seqpool =
-          jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
-              attr);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         attr.h = static_cast<int>(lod[i + 1] - lod[i]);
+        auto seqpool =
+            jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
+                attr);
         seqpool(src, dst, &attr);
         dst += attr.w;
         src += attr.h * attr.w;

From 92201d3956a4f64615baf5bc9e979bcfc6bd09bd Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Fri, 4 Jan 2019 06:41:40 +0000
Subject: [PATCH 05/28] support avg and sqrt pool and add mkl impl

test=develop
---
 .../operators/jit/more/mkl/CMakeLists.txt     |  1 +
 paddle/fluid/operators/jit/more/mkl/mkl.cc    | 31 +++++++++++++++++++
 paddle/fluid/operators/jit/more/mkl/mkl.h     | 26 ++++++++++++++++
 paddle/fluid/operators/jit/refer/refer.h      |  9 ++++++
 4 files changed, 67 insertions(+)

diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
index 863cc720d6..f5ed2f0572 100644
--- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl)
 USE_JITKERNEL_MORE(kVExp, mkl)
 USE_JITKERNEL_MORE(kVSigmoid, mkl)
 USE_JITKERNEL_MORE(kVTanh, mkl)
+USE_JITKERNEL_MORE(kSeqPool, mkl)
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc
index a5b088d481..5a499ac2c0 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.cc
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc
@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) {
   platform::dynload::vdExp(n, x, y);
 }
 
+template <>
+void VCopy<float>(const float* x, float* y, int n) {
+  platform::dynload::cblas_scopy(n, x, 1, y, 1);  // y = x, unit strides
+}
+
+template <>
+void VCopy<double>(const double* x, double* y, int n) {
+  platform::dynload::cblas_dcopy(n, x, 1, y, 1);  // y = x, unit strides
+}
+
+template <>
+void VAXPY<float>(float a, const float* x, float* y, int n) {
+  platform::dynload::cblas_saxpy(n, a, x, 1, y, 1);  // y = a * x + y
+}
+
+template <>
+void VAXPY<double>(double a, const double* x, double* y, int n) {
+  platform::dynload::cblas_daxpy(n, a, x, 1, y, 1);  // y = a * x + y
+}
+
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
 template <>
 bool VMulKernel<float>::UseMe(const int& d) const {
@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const {
   return d > 7;
 }
 
+template <>
+bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
+  return true;  // no restriction on h, w or pool type
+}
+
+template <>
+bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
+  return true;  // no restriction on h, w or pool type
+}
+
 #define AWALYS_USE_ME_WITH_DOUBLE(func)                  \
   template <>                                            \
   bool func##Kernel<double>::UseMe(const int& d) const { \
@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal);
 REGISTER_MKL_KERNEL(kVExp, VExp);
 REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
 REGISTER_MKL_KERNEL(kVTanh, VTanh);
+REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
 
 #undef REGISTER_MKL_KERNEL
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h
index ee1031c028..0a3816db24 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.h
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <cmath>
 #include <type_traits>
 #include "paddle/fluid/operators/jit/kernel_base.h"
 
@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n);
 template <typename T>
 void VExp(const T* x, T* y, int n);
 
+template <typename T>
+void VCopy(const T* x, T* y, int n);
+
+template <typename T>
+void VAXPY(T a, const T* x, T* y, int n);
+
 template <typename T>
 void VSigmoid(const T* x, T* y, int n) {
   const T min = SIGMOID_THRESHOLD_MIN;
@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) {
   }
 }
 
+template <typename T>  // pools h rows of width w from x into the single row y
+void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
+  VCopy<T>(x, y, attr->w);  // y = row 0; assumes attr->h >= 1
+  for (int h = 1; h < attr->h; ++h) {  // '<' also terminates when attr->h < 1
+    VAXPY<T>(static_cast<T>(1), x + h * attr->w, y, attr->w);
+  }
+  if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
+    T scalar = static_cast<T>(1);
+    if (attr->type == SeqPoolType::kAvg) {
+      scalar = scalar / static_cast<T>(attr->h);  // avg: sum / h
+    } else {
+      scalar = scalar / std::sqrt(static_cast<T>(attr->h));  // sum / sqrt(h)
+    }
+    VScal<T>(&scalar, y, y, attr->w);  // scale y in place
+  }
+}
+
 #define DECLARE_MKL_KERNEL(name, tuples)                             \
   template <typename T>                                              \
   class name##Kernel : public KernelMore<tuples<T>> {                \
@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
 DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
 DECLARE_MKL_KERNEL(VTanh, XYNTuples);
 
+DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
+
 #undef DECLARE_MKL_KERNEL
 
 }  // namespace mkl
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index c2aa922528..4e19783c86 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -344,6 +344,15 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
       src += attr->w;
     }
   }
+  if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
+    T scalar = static_cast<T>(1);
+    if (attr->type == SeqPoolType::kAvg) {
+      scalar = scalar / static_cast<T>(attr->h);
+    } else {
+      scalar = scalar / std::sqrt(static_cast<T>(attr->h));
+    }
+    VScal<T>(&scalar, y, y, attr->w);
+  }
 }
 
 #define DECLARE_REFER_KERNEL(name, tuples)             \

From f4c990e7b8493304b61249417aaaca45d95e5174 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 12:54:37 +0800
Subject: [PATCH 06/28] Add fused embedding ops

---
 .../fused/fused_embedding_seq_pool_op.cc      | 194 ++++++++++++++++++
 .../fused/fused_embedding_seq_pool_op.h       | 142 +++++++++++++
 2 files changed, 336 insertions(+)
 create mode 100644 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
 create mode 100644 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h

diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
new file mode 100644
index 0000000000..fe4c73f472
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
@@ -0,0 +1,194 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h"
+#include "paddle/fluid/framework/var_type_inference.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("W"),
+                   "Input W of FusedEmbeddingSeqPoolOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                   "Input Ids of FusedEmbeddingSeqPoolOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output of FusedEmbeddingSeqPoolOp should not be null.");
+
+    auto table_dims = ctx->GetInputDim("W");
+    auto ids_dims = ctx->GetInputDim("Ids");
+    const std::string& combiner = ctx->Attrs().Get<std::string>("combiner");
+
+    PADDLE_ENFORCE_EQ(table_dims.size(), 2);
+    PADDLE_ENFORCE_GE(ids_dims.size(), 1,
+                      "The dim size of the 'Ids' tensor must be at least 1.");
+    PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1,
+                      "The last dimension of the 'Ids' tensor must be 1.");
+    // only the "sum" combiner is implemented so far
+    PADDLE_ENFORCE_EQ(combiner, "sum");
+
+    int64_t last_dim = table_dims[1];  // output row width, completed below
+    for (int i = 1; i != ids_dims.size(); ++i) {
+      last_dim *= ids_dims[i];  // W width times every non-leading Ids dim
+    }
+
+    if (ctx->IsRuntime()) {
+      framework::Variable* ids_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
+      const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
+
+      // at run time Ids must carry exactly one LoD level
+      PADDLE_ENFORCE_EQ(ids_lod.size(), 1u,
+                        "The LoD level of Input(Ids) must be 1");
+      PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+
+      int64_t batch_size = ids_lod[0].size() - 1;  // one output row per seq
+
+      // at run time the shape mapping from Ids to Out is
+      // [seq_length, 1] -> [batch_size, embedding_size]
+      ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
+    } else {
+      // at compile time only the declared LoD level of Ids can be checked
+      framework::VarDesc* ids_desc =
+          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
+      PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
+
+      // at compile time the batch size is unknown, so the mapping is
+      // [-1, 1] -> [-1, embedding_size]
+      ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("W",
+             "(Tensor) The input represents embedding tensors, "
+             "which is a learnable parameter.");
+    AddInput("Ids",
+             "An input with type int32 or int64 "
+             "contains the ids to be looked up in W. "
+             "The last dimension size must be 1.");
+    AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddAttr<std::string>("combiner",
+                         "(string, default sum) "
+                         "A string specifying the reduction op. Currently sum "
+                         "are supported, sum computes the weighted sum of the "
+                         "embedding results for each row.")
+        .SetDefault("sum");
+    // NOTE(minqiyang): grad_inplace is a temporary attribute,
+    // please do NOT set this attribute in the python layer.
+    AddAttr<bool>("grad_inplace",
+                  "(boolean, default false) "
+                  "If the grad op reuse the input's variable.")
+        .SetDefault(false);
+    AddAttr<bool>("is_sparse",
+                  "(boolean, default false) "
+                  "Sparse update.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+FusedEmbeddingSeqPool Operator.
+
+Computes embeddings for the given ids and weights.
+
+This operator is used to perform lookups on the parameter W,
+then computes the weighted sum of the lookups results for each row
+and concatenated into a dense tensor.
+
+The input Ids should carry the LoD (Level of Details) information.
+And the output will change the LoD information with input Ids.
+
+)DOC");
+  }
+};
+
+class FusedEmbeddingSeqPoolOpGradDescMaker
+    : public framework::DefaultGradOpDescMaker<true> {
+  using ::paddle::framework::DefaultGradOpDescMaker<
+      true>::DefaultGradOpDescMaker;  // reuse the base-class constructors
+
+ protected:
+  virtual std::string GradOpType() const {  // type name of the grad op
+    return "fused_embedding_seq_pool_grad";
+  }
+};
+
+class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto table_dims = ctx->GetInputDim("W");  // W@GRAD takes the shape of W
+    ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+class FusedEmbeddingSeqPoolOpGradVarTypeInference
+    : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {  // sparse updates emit the W gradient as SelectedRows
+      VLOG(3) << "fused_embedding_seq_pool_grad op "
+              << framework::GradVarName("W") << " is set to SelectedRows";
+      block->Var(out_var_name)
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {  // otherwise the W gradient is declared as a LoDTensor
+      VLOG(3) << "fused_embedding_seq_pool_grad op "
+              << framework::GradVarName("W") << " is set to LoDTensor";
+      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+    }  // NOTE(review): the line below looks up a var literally named "W"
+    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;  // shorthand for the registrations below
+REGISTER_OPERATOR(fused_embedding_seq_pool, ops::FusedEmbeddingSeqPoolOp,
+                  ops::FusedEmbeddingSeqPoolOpGradDescMaker,
+                  ops::FusedEmbeddingSeqPoolOpMaker);
+REGISTER_OPERATOR(fused_embedding_seq_pool_grad,
+                  ops::FusedEmbeddingSeqPoolOpGrad,
+                  ops::FusedEmbeddingSeqPoolOpGradVarTypeInference);
+
+REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool,  // CPU only: float, double
+                       ops::FusedEmbeddingSeqPoolKernel<float>,
+                       ops::FusedEmbeddingSeqPoolKernel<double>);
+REGISTER_OP_CPU_KERNEL(fused_embedding_seq_pool_grad,
+                       ops::FusedEmbeddingSeqPoolGradKernel<float>,
+                       ops::FusedEmbeddingSeqPoolGradKernel<double>);
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
new file mode 100644
index 0000000000..38dfae8ad6
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -0,0 +1,142 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/math/blas.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using SelectedRows = framework::SelectedRows;
+using DDim = framework::DDim;
+
+template <typename T>  // sums the embedding rows of each LoD sequence of Ids
+struct EmbeddingVSumFunctor {
+  void operator()(const framework::ExecutionContext &context,
+                  const LoDTensor *table_t, const LoDTensor *ids_t,
+                  LoDTensor *output_t) {
+    auto *table = table_t->data<T>();
+    int64_t row_number = table_t->dims()[0];  // rows in the embedding table
+    int64_t row_width = table_t->dims()[1];   // width of one embedding
+    int64_t last_dim = output_t->dims()[1];   // width of one output row
+    const int64_t *ids = ids_t->data<int64_t>();
+    auto ids_lod = ids_t->lod()[0];
+    int64_t ids_count = ids_t->numel() / ids_lod.back();  // ids per LoD row
+
+    auto *output = output_t->mutable_data<T>(context.GetPlace());
+
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
+      size_t begin = ids_lod[i] * ids_count;
+      for (int64_t j = 0; j != ids_count; ++j) {  // first row seeds the sums
+        PADDLE_ENFORCE_LT(ids[begin + j], row_number);
+        PADDLE_ENFORCE_GE(ids[begin + j], 0, "ids %d", i);
+        blas.VCOPY(row_width, table + ids[begin + j] * row_width,
+                   output + i * last_dim + j * row_width);
+      }
+
+      for (int64_t r = (ids_lod[i] + 1) * ids_count;  // remaining rows of i
+           r < ids_lod[i + 1] * ids_count; ++r) {
+        PADDLE_ENFORCE_LT(ids[r], row_number);
+        PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i);
+        blas.AXPY(row_width, 1., table + ids[r] * row_width,
+                  output + i * last_dim + (r % ids_count) * row_width);
+      }
+    }
+  }
+};
+
+template <typename T>  // forward kernel: lookup + per-sequence pooling
+class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    const LoDTensor *ids_t = context.Input<LoDTensor>("Ids");  // int tensor
+    LoDTensor *output_t = context.Output<LoDTensor>("Out");    // float tensor
+    const LoDTensor *table_var = context.Input<LoDTensor>("W");
+    const std::string &combiner_type = context.Attr<std::string>("combiner");
+
+    if (combiner_type == "sum") {  // any other combiner leaves Out untouched
+      EmbeddingVSumFunctor<T> functor;
+      functor(context, table_var, ids_t, output_t);
+    }
+  }
+};
+
+template <typename T>  // backward kernel: builds the sparse gradient of W
+class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *table_var = context.InputVar("W");
+    DDim table_dim;  // shape of W, read from whichever container holds it
+    if (table_var->IsType<LoDTensor>()) {
+      table_dim = context.Input<LoDTensor>("W")->dims();
+    } else if (table_var->IsType<SelectedRows>()) {
+      auto *table_t = context.Input<SelectedRows>("W");
+      table_dim = table_t->value().dims();
+    } else {
+      PADDLE_THROW(
+          "The parameter W of a LookupTable "
+          "must be either LoDTensor or SelectedRows");
+    }
+
+    bool is_sparse = context.Attr<bool>("is_sparse");
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
+    if (is_sparse) {
+      auto *ids = context.Input<LoDTensor>("Ids");
+      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
+
+      auto *ids_data = ids->data<int64_t>();
+      int64_t ids_num = ids->numel();  // dW gets one row per looked-up id
+      auto lod = ids->lod()[0];
+      int64_t row_width = d_output->dims()[1];
+
+      framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
+      new_rows->resize(ids_num);  // rows of dW are the raw ids, in input order
+      std::memcpy(&(*new_rows)[0], ids_data, ids_num * sizeof(int64_t));
+
+      auto *d_table_value = d_table->mutable_value();
+      d_table_value->Resize({ids_num, table_dim[1]});
+      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
+      const T *d_output_data = d_output->data<T>();
+
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+        int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+        int64_t in_offset = lod[i] * row_width;
+        const T *out_pos = d_output_data + i * row_width;
+        T *in_pos = d_table_data + in_offset;
+        for (int r = 0; r != h; ++r) {  // every row of sequence i gets dOut[i]
+          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
+        }
+      }
+    } else {  // dense gradients: nothing is computed, only an error is logged
+      LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

From dc0ecffd6c4115019cfcbcc13b17a20511888c9b Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 12:55:03 +0800
Subject: [PATCH 07/28] Add ut for fused ops

---
 .../unittests/test_fused_emb_seq_pool_op.py   | 51 +++++++++++++++++++
 1 file changed, 51 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py

diff --git a/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
new file mode 100644
index 0000000000..584e309bef
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
@@ -0,0 +1,51 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid.op import Operator
+import paddle.compat as cpt
+
+
+class TestFusedEmbeddingSeqPoolOp(OpTest):
+    def setUp(self):
+        self.op_type = "fused_embedding_seq_pool"
+        self.emb_size = 2
+        table = np.random.random((17, self.emb_size)).astype("float32")
+        ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]],
+                        [[16], [1]]]).astype("int64")
+        # insert a unit axis after the row axis before feeding the op
+        ids_expand = np.expand_dims(ids, axis=1)
+        self.lod = [[3, 1]]  # sequence lengths: rows 0-2, then row 3
+        self.attrs = {'is_sparse': True}
+        self.inputs = {'W': table, 'Ids': (ids_expand, self.lod)}
+        self.outputs = {
+            'Out': np.reshape(
+                np.array([
+                    table[[4, 3]] + table[[4, 3]] + table[[2, 1]],
+                    table[[16, 1]]
+                ]), [len(self.lod[0]), 2 * self.emb_size])
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()

From e0591deebc02202c4ae8bfc95f31be606b8192b8 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Fri, 4 Jan 2019 14:40:43 +0000
Subject: [PATCH 08/28] enhance seqpool jitcode

---
 paddle/fluid/operators/jit/benchmark.cc   |   4 +-
 paddle/fluid/operators/jit/gen/seqpool.cc |  55 +--------
 paddle/fluid/operators/jit/gen/seqpool.h  | 134 ++++++++++++++++++++--
 3 files changed, 126 insertions(+), 67 deletions(-)

diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index f64e43389a..37a552fb6d 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -194,8 +194,8 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void BenchSeqPoolKernel() {
   std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
   for (auto type : pool_types) {
-    for (int h : TestSizes()) {
-      for (int w : TestSizes()) {
+    for (int w : TestSizes()) {
+      for (int h : TestSizes()) {
         const jit::seq_pool_attr_t attr(h, w, type);
         std::vector<T> x(h * w), y(w);
         RandomVec<T>(h * w, x.data(), -2.f, 2.f);
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
index ce6801b030..fd83f83436 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.cc
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() {
     mov(reg32_scalar, scalar);
   }
 
-  // TODO(TJ): make height load from params
   const int group_len = max_num_regs * block * sizeof(float);
   for (int g = 0; g < num_groups; ++g) {
     pool_height<ymm_t>(g * group_len, block, max_num_regs);
@@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() {
     pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
   }
 
-  // rest part
+  // part of rest_w * height
   const int rest = w_ % block;
-  const bool has_block4 = rest / 4 > 0;
-  const bool has_block2 = (rest % 4) / 2 > 0;
-  const bool has_block1 = (rest % 2) == 1;
-  const int w_offset = num_block * YMM_FLOAT_BLOCK * sizeof(float);
-  for (int h = 0; h < h_; ++h) {
-    int offset = h * w_ * sizeof(float) + w_offset;
-    const int shift_regs = (h == 0) ? 0 : max_num_regs;
-    int reg_idx = 0;
-    if (has_block4) {
-      vmovups(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
-      offset += sizeof(float) * 4;
-      reg_idx++;
-    }
-    if (has_block2) {
-      vmovq(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
-      offset += sizeof(float) * 2;
-      reg_idx++;
-    }
-    if (has_block1) {
-      vmovss(xmm_t(reg_idx + shift_regs), ptr[param1 + offset]);
-      reg_idx++;
-    }
-    rest_num_regs = reg_idx;
-    if (h > 0) {
-      for (int i = 0; i < reg_idx; ++i) {
-        vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
-      }
-    }
-  }
-  // save right now
-  int offset = w_offset;
-  if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-    vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
-    for (int i = 0; i < rest_num_regs; ++i) {
-      vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
-    }
-  }
-  int reg_idx = 0;
-  if (has_block4) {
-    vmovups(ptr[param2 + offset], xmm_t(reg_idx));
-    offset += sizeof(float) * 4;
-    reg_idx++;
-  }
-  if (has_block2) {
-    vmovq(ptr[param2 + offset], xmm_t(reg_idx));
-    offset += sizeof(float) * 2;
-    reg_idx++;
-  }
-  if (has_block1) {
-    vmovss(ptr[param2 + offset], xmm_t(reg_idx));
-  }
+  pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
   ret();
 }
 
diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h
index eb2d191382..48288d8c2a 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.h
+++ b/paddle/fluid/operators/jit/gen/seqpool.h
@@ -17,6 +17,7 @@
 #include <string>
 #include "glog/logging.h"
 #include "paddle/fluid/operators/jit/gen/jitcode.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace operators {
@@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode {
       base += "_Sqrt";
     }
     base += ("_W" + std::to_string(w_));
-    // TODO(TJ): make h load from params
-    base += ("_H" + std::to_string(h_));
     return base.c_str();
   }
   void genCode() override;
@@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode {
  protected:
   template <typename JMM>
   void pool_height(int w_offset, int block, int max_num_regs) {
-    for (int h = 0; h < h_; ++h) {
-      int offset = h * w_ * sizeof(float) + w_offset;
-      const int shift_regs = (h == 0) ? 0 : max_num_regs;
-      for (int i = 0; i < max_num_regs; ++i) {
-        vmovups(JMM(i + shift_regs), ptr[param1 + offset]);
-        offset += sizeof(float) * block;
-      }
-      if (h > 0) {
-        // sum anyway
+    int offset = w_offset;
+    for (int i = 0; i < max_num_regs; ++i) {
+      vmovups(JMM(i), ptr[param1 + offset]);
+      offset += sizeof(float) * block;
+    }
+    if (h_ > 1) {
+      Label l_next_h;
+      mov(reg_h, 1);
+      mov(reg_tmp, param1);
+      add(reg_tmp, w_ * sizeof(float) + w_offset);
+      L(l_next_h);
+      {
+        mov(reg_ptr_src_i, reg_tmp);
         for (int i = 0; i < max_num_regs; ++i) {
+          vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
+          // sum anyway
           vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
+          add(reg_ptr_src_i, sizeof(float) * block);
         }
+        inc(reg_h);
+        add(reg_tmp, w_ * sizeof(float));
+        cmp(reg_h, h_);
+        jl(l_next_h, T_NEAR);
       }
     }
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
       vbroadcastss(JMM(max_num_regs), reg32_scalar);
     }
-    int offset = w_offset;
+    offset = w_offset;
     for (int i = 0; i < max_num_regs; ++i) {
       if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
         vmulps(JMM(i), JMM(i), JMM(max_num_regs));
@@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode {
     }
   }
 
+  void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
+    const int rest_used_num_regs = load_rest(rest, w_offset, 0);
+    const bool has_block4 = rest / 4 > 0;
+    const bool has_block2 = (rest % 4) / 2 > 0;
+    const bool has_block1 = (rest % 2) == 1;
+    if (h_ > 1) {
+      Label l_next_h;
+      mov(reg_h, 1);
+      mov(reg_tmp, param1);
+      add(reg_tmp, w_ * sizeof(float) + w_offset);
+      L(l_next_h);
+      {
+        // int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
+        // max_num_regs);
+        int reg_idx = 0;
+        mov(reg_ptr_src_i, reg_tmp);
+        if (has_block4) {
+          vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
+          add(reg_ptr_src_i, sizeof(float) * 4);
+          reg_idx++;
+        }
+        if (has_block2) {
+          vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
+          add(reg_ptr_src_i, sizeof(float) * 2);
+          reg_idx++;
+        }
+        if (has_block1) {
+          vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
+          reg_idx++;
+        }
+        PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
+                          "All heights should use same regs");
+        for (int i = 0; i < reg_idx; ++i) {
+          vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
+        }
+        inc(reg_h);
+        add(reg_tmp, w_ * sizeof(float));
+        cmp(reg_h, h_);
+        jl(l_next_h, T_NEAR);
+      }
+    }
+    // save right now
+    if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
+      vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
+      for (int i = 0; i < rest_used_num_regs; ++i) {
+        vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
+      }
+    }
+    save_rest(rest, w_offset);
+  }
+
+  // return the number of used regs, use start from reg 0
+  int load_rest(int rest, int w_offset, const int num_shift_regs,
+                const int reg_start = 0) {
+    const bool has_block4 = rest / 4 > 0;
+    const bool has_block2 = (rest % 4) / 2 > 0;
+    const bool has_block1 = (rest % 2) == 1;
+    int reg_idx = reg_start;
+    if (has_block4) {
+      vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
+      w_offset += sizeof(float) * 4;
+      reg_idx++;
+    }
+    if (has_block2) {
+      vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
+      w_offset += sizeof(float) * 2;
+      reg_idx++;
+    }
+    if (has_block1) {
+      vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
+      reg_idx++;
+    }
+    return reg_idx;
+  }
+
+  // use reg start from 0
+  void save_rest(int rest, int w_offset, int reg_start = 0) {
+    const bool has_block4 = rest / 4 > 0;
+    const bool has_block2 = (rest % 4) / 2 > 0;
+    const bool has_block1 = (rest % 2) == 1;
+    int reg_idx = reg_start;
+    if (has_block4) {
+      vmovups(ptr[param2 + w_offset], xmm_t(reg_idx));
+      w_offset += sizeof(float) * 4;
+      reg_idx++;
+    }
+    if (has_block2) {
+      vmovq(ptr[param2 + w_offset], xmm_t(reg_idx));
+      w_offset += sizeof(float) * 2;
+      reg_idx++;
+    }
+    if (has_block1) {
+      vmovss(ptr[param2 + w_offset], xmm_t(reg_idx));
+    }
+  }
+
  private:
   int h_;
   int w_;
@@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode {
   reg64_t param2{abi_param2};
   reg64_t param3{abi_param3};
   reg32_t reg32_scalar{r8d};
+
+  reg64_t reg_h{r9};
+  reg64_t reg_ptr_src_i{r10};
+  reg64_t reg_tmp{r11};
 };
 
 }  // namespace gen

From 0145f40f4576fa035b92e3876ca9c4cfefbc5c52 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Sat, 5 Jan 2019 11:34:15 +0000
Subject: [PATCH 09/28] use height from params of jitcode

---
 paddle/fluid/operators/jit/benchmark.cc       |   3 +-
 paddle/fluid/operators/jit/gen/seqpool.cc     |  17 +-
 paddle/fluid/operators/jit/gen/seqpool.h      | 162 ++++++++++--------
 paddle/fluid/operators/jit/kernel_base.h      |   6 +-
 paddle/fluid/operators/jit/kernel_key.cc      |   6 +-
 paddle/fluid/operators/jit/refer/refer.h      |   1 -
 paddle/fluid/operators/jit/test.cc            |   7 +-
 .../fluid/operators/math/sequence_pooling.cc  |  12 +-
 8 files changed, 117 insertions(+), 97 deletions(-)

diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 37a552fb6d..4cbada4a5b 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -195,8 +195,9 @@ void BenchSeqPoolKernel() {
   std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
   for (auto type : pool_types) {
     for (int w : TestSizes()) {
+      jit::seq_pool_attr_t attr(w, type);
       for (int h : TestSizes()) {
-        const jit::seq_pool_attr_t attr(h, w, type);
+        attr.h = h;
         std::vector<T> x(h * w), y(w);
         RandomVec<T>(h * w, x.data(), -2.f, 2.f);
         const T* x_data = x.data();
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
index fd83f83436..d651f282bf 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.cc
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/seqpool.h"
+#include <stddef.h>  // offsetof
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -21,20 +22,22 @@ namespace operators {
 namespace jit {
 namespace gen {
 
+thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
+    1.f};  // TODO(TJ): try move to private
+
 void SeqPoolJitCode::genCode() {
   constexpr int block = YMM_FLOAT_BLOCK;
   constexpr int max_num_regs = 8;
   const int num_block = w_ / block;
   const int num_groups = num_block / max_num_regs;
   int rest_num_regs = num_block % max_num_regs;
-  if (type_ == SeqPoolType::kAvg) {
-    float scalar = 1.f / h_;
-    mov(reg32_scalar, scalar);
-  } else if (type_ == SeqPoolType::kSqrt) {
-    float scalar = 1.f / std::sqrt(static_cast<float>(h_));
-    mov(reg32_scalar, scalar);
+  mov(reg32_int_h, dword[param_attr]);
+  if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
+    mov(reg_tmp, reinterpret_cast<size_t>(float_h));
+    fild(dword[param_attr]);
+    fstp(dword[reg_tmp]);
+    mov(reg32_fp_h, dword[reg_tmp]);
   }
-
   const int group_len = max_num_regs * block * sizeof(float);
   for (int g = 0; g < num_groups; ++g) {
     pool_height<ymm_t>(g * group_len, block, max_num_regs);
diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h
index 48288d8c2a..c61bf27cc1 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.h
+++ b/paddle/fluid/operators/jit/gen/seqpool.h
@@ -16,6 +16,7 @@
 
 #include <string>
 #include "glog/logging.h"
+#include "paddle/fluid/operators/jit/gen/act.h"  // for ones
 #include "paddle/fluid/operators/jit/gen/jitcode.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode {
   explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
                           size_t code_size = 256 * 1024,
                           void* code_ptr = nullptr)
-      : JitCode(code_size, code_ptr), h_(attr.h), w_(attr.w), type_(attr.type) {
+      : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
     if (type_ != SeqPoolType::kSum) {
       LOG(FATAL) << "Only support sum pool yet ";
     }
@@ -55,39 +56,48 @@ class SeqPoolJitCode : public JitCode {
   void pool_height(int w_offset, int block, int max_num_regs) {
     int offset = w_offset;
     for (int i = 0; i < max_num_regs; ++i) {
-      vmovups(JMM(i), ptr[param1 + offset]);
+      vmovups(JMM(i), ptr[param_src + offset]);
       offset += sizeof(float) * block;
     }
-    if (h_ > 1) {
-      Label l_next_h;
-      mov(reg_h, 1);
-      mov(reg_tmp, param1);
-      add(reg_tmp, w_ * sizeof(float) + w_offset);
-      L(l_next_h);
-      {
-        mov(reg_ptr_src_i, reg_tmp);
-        for (int i = 0; i < max_num_regs; ++i) {
-          vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
-          // sum anyway
-          vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
-          add(reg_ptr_src_i, sizeof(float) * block);
-        }
-        inc(reg_h);
-        add(reg_tmp, w_ * sizeof(float));
-        cmp(reg_h, h_);
-        jl(l_next_h, T_NEAR);
+    cmp(reg32_int_h, 1);
+    Label l_next_h, l_h_done;
+    jle(l_h_done, T_NEAR);
+    mov(reg_h_i, 1);
+    mov(reg_tmp, param_src);
+    add(reg_tmp, w_ * sizeof(float) + w_offset);
+    L(l_next_h);
+    {
+      mov(reg_ptr_src_i, reg_tmp);
+      for (int i = 0; i < max_num_regs; ++i) {
+        vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
+        // sum anyway
+        vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
+        add(reg_ptr_src_i, sizeof(float) * block);
       }
+      inc(reg_h_i);
+      add(reg_tmp, w_ * sizeof(float));
+      cmp(reg_h_i, reg32_int_h);
+      jl(l_next_h, T_NEAR);
     }
+    L(l_h_done);
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-      vbroadcastss(JMM(max_num_regs), reg32_scalar);
+      mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
+      vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
+      movd(JMM(max_num_regs + 1), reg32_fp_h);
+      if (type_ == SeqPoolType::kSqrt) {
+        vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
+      }
+      vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
+      vbroadcastss(JMM(max_num_regs),
+                   JMM(max_num_regs + 2));  // TODO(TJ): fix me
     }
     offset = w_offset;
     for (int i = 0; i < max_num_regs; ++i) {
       if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
         vmulps(JMM(i), JMM(i), JMM(max_num_regs));
       }
-      vmovups(ptr[param2 + offset], JMM(i));
+      vmovups(ptr[param_dst + offset], JMM(i));
       offset += sizeof(float) * block;
     }
   }
@@ -97,47 +107,54 @@ class SeqPoolJitCode : public JitCode {
     const bool has_block4 = rest / 4 > 0;
     const bool has_block2 = (rest % 4) / 2 > 0;
     const bool has_block1 = (rest % 2) == 1;
-    if (h_ > 1) {
-      Label l_next_h;
-      mov(reg_h, 1);
-      mov(reg_tmp, param1);
-      add(reg_tmp, w_ * sizeof(float) + w_offset);
-      L(l_next_h);
-      {
-        // int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
-        // max_num_regs);
-        int reg_idx = 0;
-        mov(reg_ptr_src_i, reg_tmp);
-        if (has_block4) {
-          vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
-          add(reg_ptr_src_i, sizeof(float) * 4);
-          reg_idx++;
-        }
-        if (has_block2) {
-          vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
-          add(reg_ptr_src_i, sizeof(float) * 2);
-          reg_idx++;
-        }
-        if (has_block1) {
-          vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
-          reg_idx++;
-        }
-        PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
-                          "All heights should use same regs");
-        for (int i = 0; i < reg_idx; ++i) {
-          vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
-        }
-        inc(reg_h);
-        add(reg_tmp, w_ * sizeof(float));
-        cmp(reg_h, h_);
-        jl(l_next_h, T_NEAR);
+    cmp(reg32_int_h, 1);
+    Label l_next_h, l_h_done;
+    jle(l_h_done, T_NEAR);
+    mov(reg_h_i, 1);
+    mov(reg_tmp, param_src);
+    add(reg_tmp, w_ * sizeof(float) + w_offset);
+    L(l_next_h);
+    {
+      int reg_idx = 0;
+      mov(reg_ptr_src_i, reg_tmp);
+      if (has_block4) {
+        vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
+        add(reg_ptr_src_i, sizeof(float) * 4);
+        reg_idx++;
+      }
+      if (has_block2) {
+        vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
+        add(reg_ptr_src_i, sizeof(float) * 2);
+        reg_idx++;
+      }
+      if (has_block1) {
+        vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
+        reg_idx++;
       }
+      PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
+                        "All heights should use same regs");
+      for (int i = 0; i < reg_idx; ++i) {
+        vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
+      }
+      inc(reg_h_i);
+      add(reg_tmp, w_ * sizeof(float));
+      cmp(reg_h_i, reg32_int_h);
+      jl(l_next_h, T_NEAR);
     }
+    L(l_h_done);
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-      vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
+      mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
+      vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
+      movd(xmm_t(max_num_regs + 1), reg32_fp_h);
+      if (type_ == SeqPoolType::kSqrt) {
+        vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
+      }
+      vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
+             xmm_t(max_num_regs + 1));
+      vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
       for (int i = 0; i < rest_used_num_regs; ++i) {
-        vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
+        vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
       }
     }
     save_rest(rest, w_offset);
@@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode {
     const bool has_block1 = (rest % 2) == 1;
     int reg_idx = reg_start;
     if (has_block4) {
-      vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
+      vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
       w_offset += sizeof(float) * 4;
       reg_idx++;
     }
     if (has_block2) {
-      vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
+      vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
       w_offset += sizeof(float) * 2;
       reg_idx++;
     }
     if (has_block1) {
-      vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
+      vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
       reg_idx++;
     }
     return reg_idx;
@@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode {
     const bool has_block1 = (rest % 2) == 1;
     int reg_idx = reg_start;
     if (has_block4) {
-      vmovups(ptr[param2 + w_offset], xmm_t(reg_idx));
+      vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
       w_offset += sizeof(float) * 4;
       reg_idx++;
     }
     if (has_block2) {
-      vmovq(ptr[param2 + w_offset], xmm_t(reg_idx));
+      vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
       w_offset += sizeof(float) * 2;
       reg_idx++;
     }
     if (has_block1) {
-      vmovss(ptr[param2 + w_offset], xmm_t(reg_idx));
+      vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
     }
   }
 
  private:
-  int h_;
   int w_;
   SeqPoolType type_;
-  reg64_t param1{abi_param1};
-  reg64_t param2{abi_param2};
-  reg64_t param3{abi_param3};
-  reg32_t reg32_scalar{r8d};
+  reg64_t param_src{abi_param1};
+  reg64_t param_dst{abi_param2};
+  reg64_t param_attr{abi_param3};
+  reg64_t reg_tmp{rax};
+
+  reg32_t reg32_int_h{r8d};
+  reg32_t reg32_fp_h{r9d};
 
-  reg64_t reg_h{r9};
-  reg64_t reg_ptr_src_i{r10};
-  reg64_t reg_tmp{r11};
+  reg64_t reg_h_i{r10};
+  reg64_t reg_ptr_src_i{r11};
 };
 
 }  // namespace gen
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 2659374650..2a7697a6f2 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -46,7 +46,7 @@ typedef enum {
 
 typedef enum {
   kNonePoolType = 0,
-  kSum,
+  kSum = 1,
   kAvg,
   kSqrt,
 } SeqPoolType;
@@ -121,10 +121,10 @@ struct GRUTuples {
 };
 
 typedef struct seq_pool_attr_s {
-  int h, w;
+  int h, w;  // h should always be the first one
   SeqPoolType type;
   seq_pool_attr_s() = default;
-  explicit seq_pool_attr_s(int height, int width, SeqPoolType pool_type)
+  explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1)
       : h(height), w(width), type(pool_type) {}
 } seq_pool_attr_t;
 
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index db78ed8ad8..61de386886 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
 template <>
 size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
   size_t key = attr.w;
-  // TODO(TJ): support height, then removed it from key
-  constexpr int w_shift = 30;
-  return (key << act_type_shift) + static_cast<int>(attr.type) +
-         (static_cast<size_t>(attr.h) << (act_type_shift + w_shift));
+  constexpr int pool_type_shift = 3;
+  return (key << pool_type_shift) + static_cast<int>(attr.type);
 }
 
 }  // namespace jit
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index 4e19783c86..b4e9c8dd10 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
 
 template <typename T>
 void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
-  PADDLE_ENFORCE(attr->type == SeqPoolType::kSum, "Only support sum yet");
   for (int w = 0; w < attr->w; ++w) {
     const T* src = x + w;
     T* dst = y + w;
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index 0f1776507a..5e05c71f40 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -439,9 +439,10 @@ void TestSeqPoolKernel() {
   // TODO(TJ): support more
   std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
   for (auto type : pool_types) {
-    for (int h : TestSizes()) {
-      for (int w : TestSizes()) {
-        const jit::seq_pool_attr_t attr(h, w, type);
+    for (int w : TestSizes()) {
+      jit::seq_pool_attr_t attr(w, type);
+      for (int h : TestSizes()) {
+        attr.h = h;
         auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> x(h * w), yref(w);
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 283e2e251a..2a47502614 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       PADDLE_ENFORCE(platform::is_cpu_place(place));
       const T* src = input.data<T>();
       T* dst = output->mutable_data<T>(place);
-      jit::seq_pool_attr_t attr;
-      attr.w = input.numel() / input.dims()[0];
-      attr.type = jit::SeqPoolType::kSum;
+      jit::seq_pool_attr_t attr(
+          static_cast<int>(input.numel() / input.dims()[0]),
+          jit::SeqPoolType::kSum);
+      auto seqpool =
+          jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
+              attr);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         attr.h = static_cast<int>(lod[i + 1] - lod[i]);
-        auto seqpool =
-            jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
-                attr);
         seqpool(src, dst, &attr);
         dst += attr.w;
         src += attr.h * attr.w;

From 123b98f417d064e780412f316f4ca43988f4d0d2 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Mon, 7 Jan 2019 06:07:23 +0000
Subject: [PATCH 10/28] refine height and codesize and support all pool

test=develop
---
 paddle/fluid/operators/jit/benchmark.cc   |  3 ++-
 paddle/fluid/operators/jit/gen/seqpool.cc | 27 +++++++++++-----------
 paddle/fluid/operators/jit/gen/seqpool.h  | 28 +++++++----------------
 paddle/fluid/operators/jit/test.cc        |  4 ++--
 4 files changed, 26 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 4cbada4a5b..bde2791add 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -192,7 +192,8 @@ void BenchGRUKernel() {
 
 template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void BenchSeqPoolKernel() {
-  std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
+  std::vector<jit::SeqPoolType> pool_types = {
+      jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
   for (auto type : pool_types) {
     for (int w : TestSizes()) {
       jit::seq_pool_attr_t attr(w, type);
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
index d651f282bf..530d24ee1f 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.cc
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -13,7 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/seqpool.h"
-#include <stddef.h>  // offsetof
+#include "paddle/fluid/operators/jit/gen/act.h"  // for exp_float_consts ones
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -22,9 +22,6 @@ namespace operators {
 namespace jit {
 namespace gen {
 
-thread_local float ALIGN32_BEG float_h[1] ALIGN32_END = {
-    1.f};  // TODO(TJ): try move to private
-
 void SeqPoolJitCode::genCode() {
   constexpr int block = YMM_FLOAT_BLOCK;
   constexpr int max_num_regs = 8;
@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() {
   int rest_num_regs = num_block % max_num_regs;
   mov(reg32_int_h, dword[param_attr]);
   if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-    mov(reg_tmp, reinterpret_cast<size_t>(float_h));
+    mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
+    vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
+    mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
     fild(dword[param_attr]);
     fstp(dword[reg_tmp]);
-    mov(reg32_fp_h, dword[reg_tmp]);
+    vmovss(xmm_t(0), ptr[reg_tmp]);
+    if (type_ == SeqPoolType::kSqrt) {
+      vsqrtps(xmm_t(0), xmm_t(0));
+    }
+    vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
+    vmovss(ptr[reg_tmp], xmm_t(1));
   }
   const int group_len = max_num_regs * block * sizeof(float);
   for (int g = 0; g < num_groups; ++g) {
@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() {
   if (rest_num_regs > 0) {
     pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
   }
-
   // part of rest_w * height
   const int rest = w_ % block;
   pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
     return platform::MayIUse(platform::avx);
   }
   size_t CodeSize(const seq_pool_attr_t& attr) const override {
-    // TODO(TJ): remove attr.h when enabled height
-    bool yes =
-        attr.type == SeqPoolType::kAvg || attr.type == SeqPoolType::kSqrt;
-    return 96 /* basic */ +
-           ((attr.w / YMM_FLOAT_BLOCK + 4 /* rest */) * 2 /* for sum */
-            * (attr.h + (yes ? 3 : 1 /*for avg or sqrt*/))) *
+    return 96 +
+           ((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) *
+                4 /* load, mul and save */ +
+            256) *
                8;
   }
   std::unique_ptr<GenBase> CreateJitCode(
diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h
index c61bf27cc1..fcbbb3c84c 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.h
+++ b/paddle/fluid/operators/jit/gen/seqpool.h
@@ -16,7 +16,6 @@
 
 #include <string>
 #include "glog/logging.h"
-#include "paddle/fluid/operators/jit/gen/act.h"  // for ones
 #include "paddle/fluid/operators/jit/gen/jitcode.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
                           size_t code_size = 256 * 1024,
                           void* code_ptr = nullptr)
       : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
-    if (type_ != SeqPoolType::kSum) {
+    if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
+          type_ == SeqPoolType::kSqrt)) {
       LOG(FATAL) << "Only support sum pool yet ";
     }
+    fp_h_[0] = 1.f;
     this->genCode();
   }
 
@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
     L(l_h_done);
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-      mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
-      vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
-      movd(JMM(max_num_regs + 1), reg32_fp_h);
-      if (type_ == SeqPoolType::kSqrt) {
-        vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
-      }
-      vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
-      vbroadcastss(JMM(max_num_regs),
-                   JMM(max_num_regs + 2));  // TODO(TJ): fix me
+      mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
+      vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
     }
     offset = w_offset;
     for (int i = 0; i < max_num_regs; ++i) {
@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
     L(l_h_done);
     // save right now
     if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
-      mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
-      vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
-      movd(xmm_t(max_num_regs + 1), reg32_fp_h);
-      if (type_ == SeqPoolType::kSqrt) {
-        vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
-      }
-      vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
-             xmm_t(max_num_regs + 1));
-      vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
+      mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
+      vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
       for (int i = 0; i < rest_used_num_regs; ++i) {
         vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
       }
@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
   }
 
  private:
+  float ALIGN32_BEG fp_h_[1] ALIGN32_END;
   int w_;
   SeqPoolType type_;
   reg64_t param_src{abi_param1};
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index 5e05c71f40..30291bfef3 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -436,8 +436,8 @@ void TestGRUKernel() {
 template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void TestSeqPoolKernel() {
   VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  // TODO(TJ): support more
-  std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
+  std::vector<jit::SeqPoolType> pool_types = {
+      jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
   for (auto type : pool_types) {
     for (int w : TestSizes()) {
       jit::seq_pool_attr_t attr(w, type);

From 4bfa110fd893ee402ba1b052ddce7f26b257b442 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 16:28:44 +0800
Subject: [PATCH 11/28] Add no lock optimize pass

test=develop
---
 CMakeLists.txt                                |   2 +
 cmake/FindJeMalloc.cmake                      |   7 +
 cmake/generic.cmake                           |   2 +-
 paddle/fluid/framework/details/CMakeLists.txt |   2 +-
 .../fluid/framework/details/build_strategy.cc |   1 +
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 .../framework/ir/lock_free_optimize_pass.cc   | 360 ++++++++++++++++++
 .../framework/ir/lock_free_optimize_pass.h    | 130 +++++++
 8 files changed, 503 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/lock_free_optimize_pass.cc
 create mode 100644 paddle/fluid/framework/ir/lock_free_optimize_pass.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d6aa8f1b85..74d869307d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
+set(CMAKE_VERBOSE_MAKEFILE on)
+
 cmake_minimum_required(VERSION 3.0)
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/cmake/FindJeMalloc.cmake b/cmake/FindJeMalloc.cmake
index 7911f77c4c..b95287160b 100644
--- a/cmake/FindJeMalloc.cmake
+++ b/cmake/FindJeMalloc.cmake
@@ -19,3 +19,10 @@ find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALL
 mark_as_advanced(
   JEMALLOC_LIBRARIES
   JEMALLOC_INCLUDE_DIR)
+
+if (JEMALLOC_FOUND)
+  add_library(jemalloc::jemalloc UNKNOWN IMPORTED)
+  set_target_properties(jemalloc::jemalloc PROPERTIES
+    IMPORTED_LOCATION ${JEMALLOC_LIBRARIES}
+    INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}")
+endif()
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 4e31392b98..05293b8b06 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -117,7 +117,7 @@ function(common_link TARGET_NAME)
   endif()
 
   if (WITH_JEMALLOC)
-    target_link_libraries(${TARGET_NAME} ${JEMALLOC_LIBRARIES})
+    target_link_libraries(${TARGET_NAME} jemalloc::jemalloc)
   endif()
 endfunction()
 
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 179aa14528..c1ba6606f1 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -94,4 +94,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
         graph_viz_pass multi_devices_graph_pass
         multi_devices_graph_print_pass multi_devices_graph_check_pass
         fuse_elewise_add_act_pass multi_batch_merge_pass
-        memory_optimize_pass)
+        memory_optimize_pass lock_free_optimize_pass)
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 43c2eb7178..f65b3598b0 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -208,3 +208,4 @@ USE_PASS(analysis_var_pass);
 USE_PASS(sequential_execution_pass);
 USE_PASS(all_reduce_deps_pass);
 USE_PASS(modify_op_lock_and_record_event_pass);
+USE_PASS(lock_free_optimize_pass);
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 6d795e1e2d..6e6db3d3ef 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -31,6 +31,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
 
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
+pass_library(lock_free_optimize_pass base)
 pass_library(fc_fuse_pass inference)
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
new file mode 100644
index 0000000000..96e7060aac
--- /dev/null
+++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
@@ -0,0 +1,360 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/lock_free_optimize_pass.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+const char kSumGradOpName[] = "sum";
+// TODO(minqiyang): only support sgd at current time, please add
+// other optimizers later.
+const char kOptimizerType[] = "sgd";
+
+// For every sum op merging the gradients of an SGD-updated weight, creates one
+// SGD op per unmerged gradient and removes the sum op, its outputs and old SGD.
+std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  PADDLE_ENFORCE(graph.get());
+  // We could collect all weights' name from SGD, where
+  // W1 <- SGD(W0, Grad0)
+  std::unordered_set<std::string> weight_var_set;
+  for (auto* node : graph->Nodes()) {
+    if (IsOpNamed(node, kOptimizerType)) {
+      auto& param_out_vars = node->Op()->Output("ParamOut");
+      PADDLE_ENFORCE(param_out_vars.size() == 1u);
+      weight_var_set.insert(param_out_vars[0]);
+    }
+  }
+
+  // find all grad's merge op via weight name, where
+  // Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
+  std::unordered_set<ir::Node*> grad_sum_op_set;
+  for (ir::Node* node : graph->Nodes()) {
+    if (IsOpNamed(node, kSumGradOpName)) {
+      for (ir::Node* output : node->outputs) {
+        // strip the last grad suffix @GRAD
+        std::string var_name = output->Name();
+        const std::string suffix(kGradVarSuffix);
+        if (var_name != suffix && var_name.size() > suffix.size() &&
+            var_name.substr(var_name.size() - suffix.size()) == suffix) {
+          // if so then strip them off
+          var_name = var_name.substr(0, var_name.size() - suffix.size());
+          if (weight_var_set.find(var_name) != weight_var_set.end()) {
+            grad_sum_op_set.insert(node);
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // get the forward op and backward op pairs, where
+  // out <- forward(X, W)
+  // Grad1 <- backward(out, X')
+  // Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
+  // W1 <- SGD(W0, Grad0)
+  for (ir::Node* node : grad_sum_op_set) {
+    for (ir::Node* merged_grad_var : node->outputs) {
+      // find the optimizers connected with sum op
+      if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) &&
+          merged_grad_var->outputs.size() == 1u) {
+        ir::Node* opt_node = merged_grad_var->outputs[0];
+        LOG(ERROR) << "Found opt node " << opt_node->Name();
+
+        // find the backward op connected with sum op
+        for (ir::Node* unmerged_grad_var : node->inputs) {
+          if (IsVarNameContains(unmerged_grad_var, kGradVarSuffix) &&
+              unmerged_grad_var->inputs.size() == 1u) {
+            ir::Node* backward_op = unmerged_grad_var->inputs[0];
+            LOG(ERROR) << "Found backward_op " << backward_op->Name();
+
+            // find the forward op related to the backward op
+            ir::Node* forward_op =
+                FindForwardOpViaBackwardOp(graph.get(), backward_op);
+
+            // validate before the log statement dereferences the pointer
+            PADDLE_ENFORCE(forward_op);
+            LOG(ERROR) << "Found forward_op " << forward_op->Name();
+
+            Node* new_optimizer_node = CreateNewSGDNode(
+                graph.get(), forward_op, backward_op, node, opt_node);
+            PADDLE_ENFORCE(new_optimizer_node);
+          }
+        }
+      }
+    }
+  }
+
+  // Remove the sum_op and its outputs and connected Optimizers
+  for (Node* sum_op : grad_sum_op_set) {
+    for (Node* sum_op_output : sum_op->outputs) {
+      for (Node* optimize_op : sum_op_output->outputs) {
+        if (optimize_op->NodeType() == Node::Type::kOperation &&
+            optimize_op->Name() == kOptimizerType) {
+          LOG(ERROR) << "remove optimize_op: " << optimize_op->Name() << "_"
+                     << optimize_op->id();
+          graph->RemoveNode(optimize_op);
+        }
+      }
+      LOG(ERROR) << "remove sum_op_output: " << sum_op_output->Name() << "_"
+                 << sum_op_output->id();
+      graph->RemoveNode(sum_op_output);
+    }
+    LOG(ERROR) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
+    graph->RemoveNode(sum_op);
+  }
+
+  // Debug dump: log every node that feeds an sgd op, together with its inputs
+  for (auto* node : graph->Nodes()) {
+    for (Node* output_node : node->outputs) {
+      if (output_node->Name() == "sgd") {
+        LOG(ERROR) << "Node link to SGD: " << node->Name() << "_" << node->id()
+                   << " --> " << output_node->Name() << "_"
+                   << output_node->id();
+        for (Node* input_node : node->inputs) {
+          LOG(ERROR) << "SGD Input link: " << input_node->Name() << "_"
+                     << input_node->id() << " --> " << node->Name() << "_"
+                     << node->id();
+        }
+      }
+    }
+  }
+
+  return graph;
+}
+
+// Creates an SGD op node that takes the gradient flowing from backward_node to
+// grad_sum_node as Grad, rewires optimize_node's neighbours to it, returns it.
+ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
+    ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
+    ir::Node* grad_sum_node, ir::Node* optimize_node) const {
+  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE(forward_node);
+  PADDLE_ENFORCE(backward_node);
+  PADDLE_ENFORCE(grad_sum_node);
+  PADDLE_ENFORCE(optimize_node);
+
+  // find the grad var node between the grad sum node and backward_node
+  std::vector<ir::Node*> grad_vars =
+      FindConnectedNode(backward_node, grad_sum_node);
+  ir::Node* grad_node = nullptr;
+  for (ir::Node* node : grad_vars) {
+    if (!ir::IsControlDepVar(*node)) {
+      grad_node = node;
+    }
+  }
+  PADDLE_ENFORCE(grad_node);
+
+  // create a new SGD node
+  OpDesc* old_desc = optimize_node->Op();
+  // keep with the same block between new optimizer and the old one
+  OpDesc new_desc(*old_desc, old_desc->Block());
+  new_desc.SetInput("Param", old_desc->Input("Param"));
+  new_desc.SetInput("LearningRate", old_desc->Input("LearningRate"));
+  new_desc.SetInput("Grad", std::vector<std::string>({grad_node->Name()}));
+  new_desc.SetOutput("ParamOut", old_desc->Output("ParamOut"));
+
+  std::vector<std::string> op_role_vars = boost::get<std::vector<std::string>>(
+      new_desc.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+  // replace the second op role var, because the grad name was
+  // changed in new optimizer
+  PADDLE_ENFORCE(!op_role_vars.empty());  // pop_back() on empty is undefined
+  op_role_vars.pop_back();
+  op_role_vars.push_back(grad_node->Name());
+  new_desc.SetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(),
+                   op_role_vars);
+  new_desc.SetType(kOptimizerType);
+
+  // set backward op's op role var, this will be used to
+  // set device_id in multi_device_pass
+  backward_node->Op()->SetAttr(
+      framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), op_role_vars);
+
+  // create the graph node for the new SGD op
+  Node* sgd_node = graph->CreateOpNode(&new_desc);
+
+  // change all outputs of the optimize_node to the new one
+  ReplaceAllDownstreamNode(optimize_node, sgd_node);
+
+  // find connected node between forward node and optimize node
+  // and replace the optimize node to new sgd node
+  std::vector<ir::Node*> forward_opt_connected_nodes =
+      FindConnectedNode(forward_node, optimize_node);
+  for (ir::Node* node : forward_opt_connected_nodes) {
+    ReplaceUpstreamNode(node, optimize_node, sgd_node);
+  }
+
+  // find connected node between backward node and optimize node
+  // and replace the optimize node to new sgd node
+  std::vector<ir::Node*> backward_opt_connected_nodes =
+      FindConnectedNode(backward_node, optimize_node);
+  for (ir::Node* node : backward_opt_connected_nodes) {
+    ReplaceUpstreamNode(node, optimize_node, sgd_node);
+  }
+
+  // SGD must have only one param and LR in
+  PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u);
+  PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u);
+
+  // LR and weight nodes should be copied
+  for (Node* upstream_node : optimize_node->inputs) {
+    if (upstream_node->Name() == old_desc->Input("LearningRate")[0] ||
+        upstream_node->Name() == old_desc->Input("Param")[0]) {
+      ReplaceUpstreamNode(upstream_node, optimize_node, sgd_node);
+    }
+  }
+
+  LOG(ERROR) << "Create new opt node" << sgd_node->Name() << "_"
+             << sgd_node->id();
+
+  return sgd_node;
+}
+
+std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
+    ir::Node* upstream_node, ir::Node* downstream_node) const {
+  std::vector<ir::Node*> result;  // outputs of upstream that feed downstream
+  for (ir::Node* out_node : upstream_node->outputs) {
+    for (ir::Node* in_node : downstream_node->inputs) {
+      if (in_node == out_node) {
+        result.push_back(in_node);
+      }
+    }
+  }
+
+  return result;
+}
+
+void LockFreeOptimizePass::ReplaceUpstreamNode(
+    ir::Node* upstream_node, ir::Node* old_optimizer_node,
+    ir::Node* new_optimizer_node) const {
+  PADDLE_ENFORCE(upstream_node);
+  PADDLE_ENFORCE(old_optimizer_node);
+  PADDLE_ENFORCE(new_optimizer_node);
+
+  // Erase the first old_optimizer_node entry from upstream_node->outputs
+  auto& output_node_vec = upstream_node->outputs;
+  for (auto output_node_iter = output_node_vec.begin();
+       output_node_iter != output_node_vec.end();) {
+    if (*output_node_iter == old_optimizer_node) {
+      output_node_vec.erase(output_node_iter);
+      break;  // only the first match is removed
+    } else {
+      ++output_node_iter;
+    }
+  }
+
+  // Connect upstream_node and new_optimizer_node in both directions
+  output_node_vec.emplace_back(new_optimizer_node);
+  new_optimizer_node->inputs.emplace_back(upstream_node);
+}
+
+void LockFreeOptimizePass::ReplaceAllDownstreamNode(
+    ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
+  PADDLE_ENFORCE(old_optimizer_node);
+  PADDLE_ENFORCE(new_optimizer_node);
+
+  for (ir::Node* downstream_node : old_optimizer_node->outputs) {
+    // Erase the first old_optimizer_node entry from downstream_node->inputs
+    auto& input_node_vec = downstream_node->inputs;
+    for (auto input_node_iter = input_node_vec.begin();
+         input_node_iter != input_node_vec.end();) {
+      if (*input_node_iter == old_optimizer_node) {
+        input_node_vec.erase(input_node_iter);
+        break;  // only the first match is removed
+      } else {
+        ++input_node_iter;
+      }
+    }
+
+    // Connect new_optimizer_node and downstream_node in both directions
+    input_node_vec.emplace_back(new_optimizer_node);
+    new_optimizer_node->outputs.emplace_back(downstream_node);
+  }
+}
+
+ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
+    ir::Graph* graph, ir::Node* backward_node) const {
+  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE(backward_node);
+
+  // derive the forward op's name by stripping the "_grad" suffix
+  std::string forward_op_name = backward_node->Name();
+  const std::string suffix("_grad");
+  if (forward_op_name != suffix && forward_op_name.size() > suffix.size() &&
+      forward_op_name.substr(forward_op_name.size() - suffix.size()) ==
+          suffix) {
+    // name is longer than the suffix and ends with it: strip it
+    forward_op_name =
+        forward_op_name.substr(0, forward_op_name.size() - suffix.size());
+  } else {
+    LOG(WARNING) << "Illegal backward node's name " << backward_node->Name()
+                 << " id " << backward_node->id();
+
+    return nullptr;
+  }
+
+  for (ir::Node* node : graph->Nodes()) {
+    if (node->Name() == forward_op_name) {
+      if (node->outputs.size() == 0u) {
+        // a candidate without outputs is treated as having no grad op
+        continue;
+      }
+
+      // each variable input of backward_op that ends with kGradVarSuffix must
+      // be named <output name> + kGradVarSuffix for some output of this node
+      bool is_related_forward_node = true;
+      for (ir::Node* backward_input : backward_node->inputs) {
+        if (IsVarNameEndsWith(backward_input, kGradVarSuffix)) {
+          bool meets_correct_output = false;
+          for (ir::Node* forward_output : node->outputs) {
+            if (forward_output->Name() + kGradVarSuffix ==
+                backward_input->Name()) {
+              meets_correct_output = true;
+              break;
+            }
+          }
+
+          if (!meets_correct_output) {
+            is_related_forward_node = false;
+            break;
+          }
+        }
+      }
+
+      if (is_related_forward_node) {
+        return node;
+      }
+    }
+  }
+
+  return nullptr;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(lock_free_optimize_pass,
+              paddle::framework::ir::LockFreeOptimizePass);
diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.h b/paddle/fluid/framework/ir/lock_free_optimize_pass.h
new file mode 100644
index 0000000000..7310f596f8
--- /dev/null
+++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.h
@@ -0,0 +1,130 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+
+#include <string>
+#include <vector>
+
+#include <boost/algorithm/string/predicate.hpp>
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Node;
+
+/*
+* Remove the sum op of all gradients of the backward op.
+* And remove the dependencies of the optimizer related to the
+* same backward op.
+*
+* Before this pass:
+*
+* forward_op1 forward_op2
+*     |            |
+*  grad_op1    grad_op2
+*        \      /
+*          \  /
+*         sum_op
+*           |
+*         sgd_op
+*
+* After this pass:
+* forward_op1 forward_op2
+*     |            |
+*  grad_op1    grad_op2
+*     |            |
+*  sgd_op1      sgd_op2
+*
+* sgd_op1 and sgd_op2 will update the same weight, which shares the same
+* memory, so we can benefit from the acceleration
+*/
+class LockFreeOptimizePass : public Pass {
+ public:
+  virtual ~LockFreeOptimizePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+
+ private:
+  // Create a new sgd node via current optimizer node
+  ir::Node* CreateNewSGDNode(ir::Graph* graph, ir::Node* forward_node,
+                             ir::Node* backward_node, ir::Node* grad_sum_node,
+                             ir::Node* optimize_node) const;
+
+  // Rewire upstream_node's output edge from the old optimizer to the new one
+  void ReplaceUpstreamNode(ir::Node* upstream_node,
+                           ir::Node* old_optimizer_node,
+                           ir::Node* new_optimizer_node) const;
+
+  // Point every downstream node of the old optimizer at the new optimizer
+  void ReplaceAllDownstreamNode(ir::Node* old_optimizer_node,
+                                ir::Node* new_optimizer_node) const;
+
+  // Find all weight variables in graph
+  bool FindAllWeightVars(ir::Graph* graph) const;
+
+  // Find the forward_op node for a backward_op node; nullptr if not found
+  ir::Node* FindForwardOpViaBackwardOp(ir::Graph* graph,
+                                       ir::Node* backward_node) const;
+
+  std::vector<ir::Node*> FindConnectedNode(ir::Node* upstream_node,
+                                           ir::Node* downstream_node) const;
+
+  inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kOperation && node->Name() == name;
+  }
+
+  inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable && node->Name() == name;
+  }
+
+  inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable &&
+           boost::algorithm::ends_with(node->Name(), name);
+  }
+
+  inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable &&
+           node->Name().find(name) != std::string::npos;
+  }
+
+  inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
+    PADDLE_ENFORCE(ctrl_dep_node);
+    PADDLE_ENFORCE(node);
+
+    return IsControlDepVar(*ctrl_dep_node) &&
+           ctrl_dep_node->inputs.size() >= 1u &&
+           ctrl_dep_node->inputs[0] == node;
+  }
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+#endif  // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_

From 00e4de04bfa0ab0b90d153694fc7c597378bac16 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 16:44:07 +0800
Subject: [PATCH 12/28] Polish code

---
 paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 38dfae8ad6..758432fd9e 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -40,7 +40,7 @@ struct EmbeddingVSumFunctor {
     int64_t row_number = table_t->dims()[0];
     int64_t row_width = table_t->dims()[1];
     int64_t last_dim = output_t->dims()[1];
-    int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
+    const int64_t *ids = ids_t->data<int64_t>();
     auto ids_lod = ids_t->lod()[0];
     int64_t ids_count = ids_t->numel() / ids_lod.back();
 

From ee59e60f779749a3d431a54f68a32ebc5624df02 Mon Sep 17 00:00:00 2001
From: Tao Luo <luotao02@baidu.com>
Date: Mon, 7 Jan 2019 16:59:48 +0800
Subject: [PATCH 13/28] update mklml version

test=develop
---
 CMakeLists.txt             |  5 -----
 cmake/external/boost.cmake |  7 ++-----
 cmake/external/mklml.cmake | 24 +++++++++++-------------
 3 files changed, 13 insertions(+), 23 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 66dcef0013..8ba8554456 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -126,11 +126,6 @@ if(ANDROID OR IOS)
     add_definitions(-DPADDLE_MOBILE_INFERENCE)
 endif()
 
-if (APPLE)
-    set(WITH_MKL OFF CACHE STRING
-        "Disable MKL for building on mac" FORCE)
-endif()
-
 if (WIN32)
     set(WITH_DISTRIBUTE OFF CACHE STRING
             "Disable DISTRIBUTE when compiling for Windows" FORCE)
diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake
index 5a78a1d1b7..12412a51a0 100644
--- a/cmake/external/boost.cmake
+++ b/cmake/external/boost.cmake
@@ -23,11 +23,8 @@ set(BOOST_PROJECT       "extern_boost")
 # checked that the devtools package of CentOS 6 installs boost 1.41.0.
 # So we use 1.41.0 here.
 set(BOOST_VER           "1.41.0")
-if((NOT DEFINED BOOST_TAR) OR (NOT DEFINED BOOST_URL))
-    message(STATUS "use pre defined download url")
-    set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
-    set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
-endif()
+set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
+set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
 
 MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}")
 
diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake
index 96127e78d6..c94878b6c7 100644
--- a/cmake/external/mklml.cmake
+++ b/cmake/external/mklml.cmake
@@ -36,19 +36,17 @@ else()
 endif()
 SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
 
-IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL))
-    MESSAGE(STATUS "use pre defined download url")
-    if(WIN32)
-        SET(MKLML_VER "mklml_win_2019.0.1.20180928" CACHE STRING "" FORCE)
-        SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
-    elseif(APPLE)
-        SET(MKLML_VER "mklml_mac_2019.0.1.20180928" CACHE STRING "" FORCE)
-        SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
-    else()
-        SET(MKLML_VER "mklml_lnx_2019.0.1.20180928" CACHE STRING "" FORCE)
-        SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
-    ENDIF()
-endif()
+SET(TIME_VERSION "2019.0.1.20181227")
+if(WIN32)
+    SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
+elseif(APPLE)
+    SET(MKLML_VER "mklml_mac_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
+else()
+    SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
+ENDIF()
 
 SET(MKLML_PROJECT       "extern_mklml")
 MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}")

From 7f45b9511aa1cf18f36709627a01a59bc1d3e661 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 22:54:01 +0800
Subject: [PATCH 14/28] Polish code

---
 paddle/fluid/framework/operator.cc | 1 +
 paddle/fluid/operators/hash_op.h   | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index f10da22aec..afece8e3d2 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -29,6 +29,7 @@ DECLARE_bool(benchmark);
 DEFINE_bool(check_nan_inf, false,
             "Checking whether operator produce NAN/INF or not. It will be "
             "extremely slow so please use this flag wisely.");
+DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op");
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h
index 9781bb0f45..1ed3ffe9aa 100644
--- a/paddle/fluid/operators/hash_op.h
+++ b/paddle/fluid/operators/hash_op.h
@@ -45,7 +45,7 @@ class HashKerel : public framework::OpKernel<T> {
     for (int idx = 0; idx < seq_length; ++idx) {
       for (int ihash = 0; ihash != num_hash; ++ihash) {
         output[idx * num_hash + ihash] =
-            XXH64(input, sizeof(int) * last_dim, ihash) % mod_by;
+            XXH32(input, sizeof(int) * last_dim, ihash) % mod_by;
       }
       input += last_dim;
     }

From 1bfbc0d963db26fcf72b9b53d568e0b102d50a5d Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 22:54:47 +0800
Subject: [PATCH 15/28] Polish code

test=develop
---
 paddle/fluid/framework/operator.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index afece8e3d2..f10da22aec 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -29,7 +29,6 @@ DECLARE_bool(benchmark);
 DEFINE_bool(check_nan_inf, false,
             "Checking whether operator produce NAN/INF or not. It will be "
             "extremely slow so please use this flag wisely.");
-DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op");
 
 namespace paddle {
 namespace framework {

From b76695418ad6cfe16f5fe54f9768fdf3b467a241 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 22:55:59 +0800
Subject: [PATCH 16/28] Polish log

test=develop
---
 .../framework/ir/lock_free_optimize_pass.cc   | 30 +++++++++----------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
index 96e7060aac..92e897ca9c 100644
--- a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
+++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc
@@ -80,7 +80,7 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
       if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) &&
           merged_grad_var->outputs.size() == 1u) {
         ir::Node* opt_node = merged_grad_var->outputs[0];
-        LOG(ERROR) << "Found opt node " << opt_node->Name();
+        VLOG(3) << "Found opt node " << opt_node->Name();
 
         // find the backward op connected with sum op
         for (ir::Node* unmerged_grad_var : node->inputs) {
@@ -88,13 +88,13 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
               unmerged_grad_var->inputs.size() == 1u) {
             ir::Node* backward_op = unmerged_grad_var->inputs[0];
 
-            LOG(ERROR) << "Found backward_op " << backward_op->Name();
+            VLOG(3) << "Found backward_op " << backward_op->Name();
 
             // find the forward op related to the backward op
             ir::Node* forward_op =
                 FindForwardOpViaBackwardOp(graph.get(), backward_op);
 
-            LOG(ERROR) << "Found forward_op " << forward_op->Name();
+            VLOG(3) << "Found forward_op " << forward_op->Name();
 
             PADDLE_ENFORCE(forward_op);
 
@@ -114,29 +114,28 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
       for (Node* optimize_op : sum_op_output->outputs) {
         if (optimize_op->NodeType() == Node::Type::kOperation &&
             optimize_op->Name() == kOptimizerType) {
-          LOG(ERROR) << "remove optimize_op: " << optimize_op->Name() << "_"
-                     << optimize_op->id();
+          VLOG(3) << "remove optimize_op: " << optimize_op->Name() << "_"
+                  << optimize_op->id();
           graph->RemoveNode(optimize_op);
         }
       }
-      LOG(ERROR) << "remove sum_op_output: " << sum_op_output->Name() << "_"
-                 << sum_op_output->id();
+      VLOG(3) << "remove sum_op_output: " << sum_op_output->Name() << "_"
+              << sum_op_output->id();
       graph->RemoveNode(sum_op_output);
     }
-    LOG(ERROR) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
+    VLOG(3) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
     graph->RemoveNode(sum_op);
   }
 
   for (auto* node : graph->Nodes()) {
     for (Node* output_node : node->outputs) {
       if (output_node->Name() == "sgd") {
-        LOG(ERROR) << "Node link to SGD: " << node->Name() << "_" << node->id()
-                   << " --> " << output_node->Name() << "_"
-                   << output_node->id();
+        VLOG(3) << "Node link to SGD: " << node->Name() << "_" << node->id()
+                << " --> " << output_node->Name() << "_" << output_node->id();
         for (Node* input_node : node->inputs) {
-          LOG(ERROR) << "SGD Input link: " << input_node->Name() << "_"
-                     << input_node->id() << " --> " << node->Name() << "_"
-                     << node->id();
+          VLOG(3) << "SGD Input link: " << input_node->Name() << "_"
+                  << input_node->id() << " --> " << node->Name() << "_"
+                  << node->id();
         }
       }
     }
@@ -226,8 +225,7 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
     }
   }
 
-  LOG(ERROR) << "Create new opt node" << sgd_node->Name() << "_"
-             << sgd_node->id();
+  VLOG(3) << "Create new opt node" << sgd_node->Name() << "_" << sgd_node->id();
 
   return sgd_node;
 }

From 5979953720ce35e5607f227d7b4c2400df0b8a35 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Mon, 7 Jan 2019 22:56:41 +0800
Subject: [PATCH 17/28] Remove debug info

test=develop
---
 CMakeLists.txt | 2 --
 1 file changed, 2 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 74d869307d..d6aa8f1b85 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-set(CMAKE_VERBOSE_MAKEFILE on)
-
 cmake_minimum_required(VERSION 3.0)
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})

From 7c7342bf125ef2859d1dd7628ad5a494ffe315b9 Mon Sep 17 00:00:00 2001
From: sneaxiy <sneaxiy@126.com>
Date: Tue, 8 Jan 2019 03:33:13 +0000
Subject: [PATCH 18/28] fix scope.var() test=develop

---
 paddle/fluid/framework/scope.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc
index a5742dbd3d..9536185609 100644
--- a/paddle/fluid/framework/scope.cc
+++ b/paddle/fluid/framework/scope.cc
@@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) {
 }
 
 Variable* Scope::Var(std::string* name) {
-  auto new_name = string::Sprintf("%p.%d", this, vars_.size());
+  SCOPE_VARS_WRITER_LOCK
+  auto new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "." +
+                  std::to_string(vars_.size());
   if (name != nullptr) {
     *name = new_name;
   }
-  SCOPE_VARS_WRITER_LOCK
   return VarInternal(new_name);
 }
 

From ed409ac9f4fa57dbf8785f24dde4b55714555fc4 Mon Sep 17 00:00:00 2001
From: sneaxiy <sneaxiy@126.com>
Date: Tue, 8 Jan 2019 03:37:59 +0000
Subject: [PATCH 19/28] Revert "Revert "Remove op handle lock"" test=develop

---
 paddle/fluid/operators/math/blas_impl.cu.h   | 134 +++++++++----------
 paddle/fluid/platform/cuda_helper.h          |  58 ++++++++
 paddle/fluid/platform/device_context.cc      |  18 ++-
 paddle/fluid/platform/device_context.h       |  76 ++++-------
 paddle/fluid/platform/device_context_test.cu |   3 -
 5 files changed, 159 insertions(+), 130 deletions(-)
 create mode 100644 paddle/fluid/platform/cuda_helper.h

diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h
index d35073029a..58f7be12ce 100644
--- a/paddle/fluid/operators/math/blas_impl.cu.h
+++ b/paddle/fluid/operators/math/blas_impl.cu.h
@@ -62,27 +62,19 @@ struct CUBlas<float> {
                       cudaDataType_t Atype, int lda, const void *B,
                       cudaDataType_t Btype, int ldb, const float *beta, void *C,
                       cudaDataType_t Ctype, int ldc) {
-    // Because the gcc 4.8 doesn't expand template parameter pack that
-    // appears in a lambda-expression, I can not use template parameter pack
-    // here.
-    auto cublas_call = [&]() {
+// Because the gcc 4.8 doesn't expand template parameter pack that
+// appears in a lambda-expression, I can not use template parameter pack
+// here.
 #if CUDA_VERSION >= 8000
-      VLOG(5) << "use_tensor_op_math: "
-              << (platform::TensorCoreAvailable() ? "True" : "False");
+    VLOG(5) << "use_tensor_op_math: "
+            << (dev_ctx->tensor_core_available() ? "True" : "False");
+    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
-          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
-          lda, B, Btype, ldb, beta, C, Ctype, ldc));
+          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
+          beta, C, Ctype, ldc));
+    });
 #else
-      PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
-#endif
-    };
-
-#if CUDA_VERSION >= 9000
-    // NOTES: To use Tensor Core, we should change the cublas config,
-    // but the cublas may be hold by multi-thread.
-    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
-#else
-    cublas_call();
+    PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
 #endif
   }
 };
@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
                       cudaDataType_t Btype, int ldb, const void *beta, void *C,
                       cudaDataType_t Ctype, int ldc,
                       cudaDataType_t computeType) {
-    auto cublas_call = [&]() {
 #if CUDA_VERSION >= 8000
-      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
 #if CUDA_VERSION >= 9000
-      bool use_tensor_op_math = platform::TensorCoreAvailable();
-      if (use_tensor_op_math) {
-        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-      }
-      VLOG(5) << "use_tensor_op_math: "
-              << (use_tensor_op_math ? "True" : "False");
+    bool use_tensor_op_math = dev_ctx->tensor_core_available();
+    if (use_tensor_op_math) {
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+    }
+    VLOG(5) << "use_tensor_op_math: "
+            << (use_tensor_op_math ? "True" : "False");
 #endif  // CUDA_VERSION >= 9000
 
+    dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
-          dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype,
-          lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo));
+          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
+          beta, C, Ctype, ldc, computeType, algo));
+    });
 #else
-      PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
-#endif
-    };
-
-#if CUDA_VERSION >= 9000
-    // NOTES: To use Tensor Core, we should change the cublas config,
-    // but the cublas may be hold by multi-thread.
-    dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
-#else
-    cublas_call();
+    PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
 #endif
   }
 };
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                        CUDA_R_32F, N);
   } else {
 #endif  // CUDA_VERSION >= 8000
-
-    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                    &alpha, B, ldb, A, lda, &beta, C, N);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
+                      lda, &beta, C, N);
+    });
 
 #if CUDA_VERSION >= 8000
   }
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
-                                  &h_beta, h_C, N);
+
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
+                                    &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
+                                    N);
+  });
 #endif  // CUDA_VERSION >= 8000
 }
 
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
   } else {
 #endif  // CUDA_VERSION >= 8000
 
-    CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K,
-                    &alpha, B, ldb, A, lda, &beta, C, ldc);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
+                      lda, &beta, C, ldc);
+    });
 
 #if CUDA_VERSION >= 8000
   }
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &alpha, B, ldb, A, lda, &beta, C,
-                                  ldc);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
+                                    B, ldb, A, lda, &beta, C, ldc);
+  });
 }
 
 template <>
 template <typename T>
 void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                              T *y) const {
-  CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
+  });
 }
 
 template <>
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                              T beta, T *C) const {
   cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-  CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1,
-                  &beta, C, 1);
+  context_.CublasCall([&](cublasHandle_t handle) {
+    CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
+  });
 }
 
 template <>
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
 
 #if CUDA_VERSION >= 9010
   if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
-    auto cublas_call = [&]() {
-      cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
-      bool use_tensor_op_math = platform::TensorCoreAvailable();
-      if (use_tensor_op_math) {
-        algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
-      }
-      VLOG(5) << "use_tensor_op_math: "
-              << (use_tensor_op_math ? "True" : "False");
-
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+    bool use_tensor_op_math = context_.tensor_core_available();
+    if (use_tensor_op_math) {
+      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+    }
+    VLOG(5) << "use_tensor_op_math: "
+            << (use_tensor_op_math ? "True" : "False");
+
+    context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
-          context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B,
-          CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C,
-          CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo));
-    };
-    auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-    dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
+          handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
+          strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
+          strideC, batchCount, CUDA_R_32F, algo));
+    });
   } else {
 #endif  // CUDA_VERSION >= 9010
 
-    CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA,
-                                  N, M, K, &alpha, B, ldb, strideB, A, lda,
-                                  strideA, &beta, C, ldc, strideC, batchCount);
+    context_.CublasCall([&](cublasHandle_t handle) {
+      CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
+                                    B, ldb, strideB, A, lda, strideA, &beta, C,
+                                    ldc, strideC, batchCount);
+    });
 
 #if CUDA_VERSION >= 9010
   }
diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h
new file mode 100644
index 0000000000..122de72e15
--- /dev/null
+++ b/paddle/fluid/platform/cuda_helper.h
@@ -0,0 +1,58 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <mutex>  // NOLINT
+
+#include "paddle/fluid/platform/dynload/cublas.h"
+#include "paddle/fluid/platform/macros.h"
+
+#if CUDA_VERSION < 9000
+enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
+#endif
+
+namespace paddle {
+namespace platform {
+
+class CublasHandleHolder {  // Owns one cublas handle bound to one stream.
+ public:
+  CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
+    PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
+    PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
+#if CUDA_VERSION >= 9000  // guards the cublasSetMathMode call
+    if (math_type == CUBLAS_TENSOR_OP_MATH) {  // else keep handle default
+      PADDLE_ENFORCE(
+          dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
+    }
+#endif
+  }
+  // NOTE(review): PADDLE_ENFORCE inside a dtor; confirm that is intended.
+  ~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
+
+  template <typename Callback>
+  inline void Call(Callback &&callback) const {
+    std::lock_guard<std::mutex> guard(mtx_);  // held for the whole callback
+    callback(handle_);
+  }
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
+
+  cublasHandle_t handle_;  // created in ctor, destroyed in dtor
+  mutable std::mutex mtx_;  // mutable so the const Call() can lock it
+};
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 022afb686b..be7f4949d6 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
-  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
+  cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
+
+  if (TensorCoreAvailable()) {
+#if CUDA_VERSION >= 9000
+    cublas_tensor_core_handle_.reset(
+        new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
+#endif
+  }
+
   if (dynload::HasCUDNN()) {
     cudnn_holder_.reset(new CudnnHolder(&stream_, place));
   }
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   WaitStreamCallback();
-  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  cublas_handle_.reset();
+  cublas_tensor_core_handle_.reset();
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
 
-cublasHandle_t CUDADeviceContext::cublas_handle() const {
-  return cublas_handle_;
+bool CUDADeviceContext::tensor_core_available() const {
+  return cublas_tensor_core_handle_ != nullptr;
 }
 
 cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 7e87580189..c81d17380c 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/temporary_allocator.h"
 #ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cuda_helper.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/gpu_info.h"
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
   std::unique_ptr<std::lock_guard<std::mutex>> guard_;
 };
 
-#if CUDA_VERSION >= 9000
-class ScopedCublasMathMode {
- public:
-  ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
-      : handle_(handle) {
-    need_reset = false;
-    PADDLE_ENFORCE(
-        platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
-        "Failed to get old cublas math mode");
-    if (old_math_mode_ != new_math_mode) {
-      PADDLE_ENFORCE(
-          platform::dynload::cublasSetMathMode(handle_, new_math_mode),
-          "Failed to set old cublas math mode");
-      need_reset = true;
-    }
-  }
-
-  ~ScopedCublasMathMode() {
-    if (need_reset) {
-      PADDLE_ENFORCE(
-          platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
-          "Failed to set old cublas math mode");
-    }
-  }
-
- private:
-  cublasHandle_t handle_;
-  cublasMath_t old_math_mode_;
-  bool need_reset;
-};
-
-#endif
-
 class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(CUDAPlace place);
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief  Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
 
-  /*! \brief  Return cublas handle in the device context. */
-  cublasHandle_t cublas_handle() const;
+  /*! \brief  Call cublas via callback(handle); the holder's lock is held. */
+  template <typename Callback>
+  inline void CublasCall(Callback&& callback) const {
+    cublas_handle_->Call(std::forward<Callback>(callback));
+  }
+
+  /*! \brief  True iff the tensor-core cublas handle was created. */
+  bool tensor_core_available() const;
+
+  /*! \brief  Run callback(handle) on the tensor-core handle under its
+      lock; fall back to the default cublas handle when it is absent. */
+  template <typename Callback>
+  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
+    if (cublas_tensor_core_handle_) {
+      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
+    } else {
+      cublas_handle_->Call(std::forward<Callback>(callback));
+    }
+  }
 
   /*! \brief  Return cudnn  handle in the device context. */
   cudnnHandle_t cudnn_handle() const;
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
 
   template <typename Callback>
   void RecordEvent(cudaEvent_t ev, Callback callback) {
-    std::lock_guard<std::mutex> guard(mtx_);
     callback();
     PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
   }
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
 
   void WaitStreamCallback() const { callback_manager_->Wait(); }
 
-#if CUDA_VERSION >= 9000
-  /*! \brief CublasCall may need to change cublas's config,
-   *  but the cublas may be hold by multi-thread, so we should
-   *  add lock here. */
-  template <typename Callback>
-  void CublasCall(Callback callback, cublasMath_t new_math) {
-    std::lock_guard<std::mutex> guard(cublas_mtx_);
-    ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
-    callback();
-  }
-#endif
-
  private:
   CUDAPlace place_;
 
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
   std::unique_ptr<CudnnHolder> cudnn_holder_;
   cudaStream_t stream_;
-  cublasHandle_t cublas_handle_;
+
+  std::unique_ptr<CublasHandleHolder> cublas_handle_;
+  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
 
   int compute_capability_;
   int runtime_version_;
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
   int multi_process_;
   int max_threads_per_mp_;
 
-  mutable std::mutex mtx_;
-
   // StreamCallbackManager is thread-safe
   std::unique_ptr<StreamCallbackManager> callback_manager_;
 
-  mutable std::mutex cublas_mtx_;
+  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
 
 template <>
diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu
index 171d2979a0..5b3aa98efb 100644
--- a/paddle/fluid/platform/device_context_test.cu
+++ b/paddle/fluid/platform/device_context_test.cu
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
     ASSERT_NE(nullptr, gpu_device);
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
     ASSERT_NE(nullptr, cudnn_handle);
-    cublasHandle_t cublas_handle = device_context->cublas_handle();
-    ASSERT_NE(nullptr, cublas_handle);
-    ASSERT_NE(nullptr, device_context->stream());
     delete device_context;
   }
 }

From 49c31e5da409f9af01182ea74a91d605e3ca9747 Mon Sep 17 00:00:00 2001
From: Tao Luo <luotao02@baidu.com>
Date: Mon, 7 Jan 2019 20:31:20 +0800
Subject: [PATCH 20/28] disable mkl for mac

test=develop
---
 CMakeLists.txt             |  5 +++++
 cmake/external/mklml.cmake | 30 +++++++++++++++---------------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8ba8554456..66dcef0013 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -126,6 +126,11 @@ if(ANDROID OR IOS)
     add_definitions(-DPADDLE_MOBILE_INFERENCE)
 endif()
 
+if (APPLE)
+    set(WITH_MKL OFF CACHE STRING
+        "Disable MKL for building on mac" FORCE)
+endif()
+
 if (WIN32)
     set(WITH_DISTRIBUTE OFF CACHE STRING
             "Disable DISTRIBUTE when compiling for Windows" FORCE)
diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake
index c94878b6c7..43322a257a 100644
--- a/cmake/external/mklml.cmake
+++ b/cmake/external/mklml.cmake
@@ -16,6 +16,12 @@ IF(NOT ${WITH_MKLML})
   return()
 ENDIF(NOT ${WITH_MKLML})
 
+IF(APPLE)
+    MESSAGE(WARNING "Mac is not supported with MKLML in Paddle yet. Force WITH_MKLML=OFF.")
+    SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in MacOS" FORCE)
+    return()
+ENDIF()
+
 INCLUDE(ExternalProject)
 SET(MKLML_DST_DIR       "mklml")
 SET(MKLML_INSTALL_ROOT  "${THIRD_PARTY_PATH}/install")
@@ -23,29 +29,23 @@ SET(MKLML_INSTALL_DIR   ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
 SET(MKLML_ROOT          ${MKLML_INSTALL_DIR})
 SET(MKLML_INC_DIR       ${MKLML_ROOT}/include)
 SET(MKLML_LIB_DIR       ${MKLML_ROOT}/lib)
-if(WIN32)
+SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
+
+SET(TIME_VERSION "2019.0.1.20181227")
+IF(WIN32)
+    SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
     SET(MKLML_LIB                 ${MKLML_LIB_DIR}/mklml.lib)
     SET(MKLML_IOMP_LIB            ${MKLML_LIB_DIR}/libiomp5md.lib)
     SET(MKLML_SHARED_LIB          ${MKLML_LIB_DIR}/mklml.dll)
     SET(MKLML_SHARED_IOMP_LIB     ${MKLML_LIB_DIR}/libiomp5md.dll)
-else()
+ELSE()  # non-Windows, non-Apple (APPLE returned above): fetch over HTTPS too
+    SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
+    SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
     SET(MKLML_LIB                 ${MKLML_LIB_DIR}/libmklml_intel.so)
     SET(MKLML_IOMP_LIB            ${MKLML_LIB_DIR}/libiomp5.so)
     SET(MKLML_SHARED_LIB          ${MKLML_LIB_DIR}/libmklml_intel.so)
     SET(MKLML_SHARED_IOMP_LIB     ${MKLML_LIB_DIR}/libiomp5.so)
-endif()
-SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
-
-SET(TIME_VERSION "2019.0.1.20181227")
-if(WIN32)
-    SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE)
-    SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
-elseif(APPLE)
-    SET(MKLML_VER "mklml_mac_${TIME_VERSION}" CACHE STRING "" FORCE)
-    SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
-else()
-    SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
-    SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
 ENDIF()
 
 SET(MKLML_PROJECT       "extern_mklml")

From 7b7d0d0caf85fc2d104ac285cfa367ff46490fa1 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Tue, 8 Jan 2019 13:30:09 +0800
Subject: [PATCH 21/28] Change hash function back

test=develop
---
 paddle/fluid/operators/hash_op.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h
index 1ed3ffe9aa..9781bb0f45 100644
--- a/paddle/fluid/operators/hash_op.h
+++ b/paddle/fluid/operators/hash_op.h
@@ -45,7 +45,7 @@ class HashKerel : public framework::OpKernel<T> {
     for (int idx = 0; idx < seq_length; ++idx) {
       for (int ihash = 0; ihash != num_hash; ++ihash) {
         output[idx * num_hash + ihash] =
-            XXH32(input, sizeof(int) * last_dim, ihash) % mod_by;
+            XXH64(input, sizeof(int) * last_dim, ihash) % mod_by;
       }
       input += last_dim;
     }

From 23bdd0a223cc3e88c62fb8f48155c83455c9fede Mon Sep 17 00:00:00 2001
From: superjomn <yanchunwei@outlook.com>
Date: Tue, 8 Jan 2019 15:11:48 +0800
Subject: [PATCH 22/28] fix analysis_tester bug

test=develop
---
 paddle/fluid/inference/analysis/analyzer_tester.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index f84e1ab6b8..4c84d02d86 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) {
        i++) {
     LOG(INFO) << "data: " << static_cast<float*>(outputs.front().data.data())[i]
               << " result: " << result[i];
-    PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
-                   result[i]);
+    EXPECT_NEAR(static_cast<float*>(outputs.front().data.data())[i], result[i],
+                1e-3);
   }
 }
 

From 69fd3fdb5206045cfcee90d98b52cf070f1dcae1 Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Tue, 8 Jan 2019 09:11:39 +0000
Subject: [PATCH 23/28] fix debug build error

test=develop
---
 paddle/fluid/inference/analysis/passes/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
index d3ea511d8f..add9b70f2c 100644
--- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
@@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps}
         ir_graph_build_pass
         ir_analysis_pass
         analysis_passes
+        subgraph_detector
         CACHE INTERNAL "")

From bc205ef37453e0f7ab1f74abb123c3367ceee3c7 Mon Sep 17 00:00:00 2001
From: sneaxiy <sneaxiy@126.com>
Date: Tue, 8 Jan 2019 10:28:01 +0000
Subject: [PATCH 24/28] fix same name func test=develop

---
 paddle/fluid/framework/var_type_traits.cc      | 8 +++++---
 paddle/fluid/framework/var_type_traits.h       | 4 ++--
 paddle/fluid/framework/var_type_traits_test.cc | 9 +++++----
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc
index c3c5bab23b..a37b1fbab8 100644
--- a/paddle/fluid/framework/var_type_traits.cc
+++ b/paddle/fluid/framework/var_type_traits.cc
@@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder {
 
 }  // namespace detail
 
-const std::type_index &ToTypeIndex(int var_id) {
+const std::type_index &VarTraitIdToTypeIndex(int var_id) {
   return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
 }
 
-const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
+const char *ToTypeName(int var_id) {
+  return VarTraitIdToTypeIndex(var_id).name();
+}
 
-int ToTypeId(const std::type_index &type) {
+int TypeIndexToVarTraitId(const std::type_index &type) {
   return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
 }
 
diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h
index cc68cf2ab8..733542e497 100644
--- a/paddle/fluid/framework/var_type_traits.h
+++ b/paddle/fluid/framework/var_type_traits.h
@@ -66,8 +66,8 @@ namespace paddle {
 namespace framework {
 
 const char *ToTypeName(int var_id);
-const std::type_index &ToTypeIndex(int var_id);
-int ToTypeId(const std::type_index &type);
+const std::type_index &VarTraitIdToTypeIndex(int var_id);
+int TypeIndexToVarTraitId(const std::type_index &type);
 
 namespace detail {
 
diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc
index 00840d634d..a47275e1ca 100644
--- a/paddle/fluid/framework/var_type_traits_test.cc
+++ b/paddle/fluid/framework/var_type_traits_test.cc
@@ -45,10 +45,11 @@ struct TypeIndexChecker {
     constexpr auto kId = VarTypeTrait<Type>::kId;
     std::type_index actual_type(typeid(Type));
     EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
-    EXPECT_EQ(ToTypeIndex(kId), actual_type);
-    EXPECT_EQ(ToTypeId(actual_type), kId);
-    EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type);
-    EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId);
+    EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type);
+    EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId);
+    EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)),
+              actual_type);
+    EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId);
 
     EXPECT_TRUE(var_id_set->count(kId) == 0);              // NOLINT
     EXPECT_TRUE(type_index_set->count(actual_type) == 0);  // NOLINT

From 55a0672378329764a1b1429d9cfc8def91317e63 Mon Sep 17 00:00:00 2001
From: chengduo <zhaochengduo@baidu.com>
Date: Tue, 8 Jan 2019 05:20:48 -0600
Subject: [PATCH 25/28] fix compute_75 of cuda_cmake (#15209)

test=develop
---
 cmake/cuda.cmake | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake
index 10ecdf0ea8..16432ce2b8 100644
--- a/cmake/cuda.cmake
+++ b/cmake/cuda.cmake
@@ -2,9 +2,11 @@ if(NOT WITH_GPU)
     return()
 endif()
 
-set(paddle_known_gpu_archs "30 35 50 52 60 61 70 75")
+set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
 set(paddle_known_gpu_archs7 "30 35 50 52")
 set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
+set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
+set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
 
 ######################################################################################
 # A function for automatic detection of GPUs installed  (if autodetection is enabled)
@@ -155,6 +157,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
   # warning for now.
   list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
   add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
+elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
+  set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
+  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
+  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
+  add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
+elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
+  set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
+  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
+  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
+  add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
 endif()
 
 include_directories(${CUDA_INCLUDE_DIRS})

From e4184008a4e4aa60fbd21d43209256ec1114186f Mon Sep 17 00:00:00 2001
From: mozga-intel <mateusz.ozga@intel.com>
Date: Tue, 8 Jan 2019 16:37:03 +0100
Subject: [PATCH 26/28] PADDLE_WITH_NGRAPH was removed from the code
 test=develop

---
 paddle/fluid/operators/ngraph/ops/binary_unnary_op.h      | 2 --
 paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h | 2 --
 paddle/fluid/operators/ngraph/ops/fill_constant_op.h      | 2 --
 paddle/fluid/operators/ngraph/ops/mean_op.h               | 2 --
 paddle/fluid/operators/ngraph/ops/mul_op.h                | 2 --
 paddle/fluid/operators/ngraph/ops/scale_op.h              | 2 --
 paddle/fluid/operators/ngraph/ops/top_k_op.h              | 2 --
 7 files changed, 14 deletions(-)

diff --git a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h
index 6610380fcf..0c0d25d0cd 100644
--- a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h
+++ b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -48,4 +47,3 @@ static void BuildUnaryNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h
index 15fbd58b02..8f5092963c 100644
--- a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h
+++ b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -58,4 +57,3 @@ std::shared_ptr<ngraph::Node> ElementwiseScalar(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h
index 5eff69e7b1..406a4314f8 100644
--- a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h
+++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -58,4 +57,3 @@ void BuildFillConstantNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/mean_op.h b/paddle/fluid/operators/ngraph/ops/mean_op.h
index 7fcf8f09cd..4c44bc4c11 100644
--- a/paddle/fluid/operators/ngraph/ops/mean_op.h
+++ b/paddle/fluid/operators/ngraph/ops/mean_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <functional>
@@ -65,4 +64,3 @@ void BuildMeanGradNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/mul_op.h b/paddle/fluid/operators/ngraph/ops/mul_op.h
index 9e12e5d7c3..4a6cbebe24 100644
--- a/paddle/fluid/operators/ngraph/ops/mul_op.h
+++ b/paddle/fluid/operators/ngraph/ops/mul_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -131,4 +130,3 @@ static void BuildMulGradNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/scale_op.h b/paddle/fluid/operators/ngraph/ops/scale_op.h
index 24ab0702aa..91a57d0be6 100644
--- a/paddle/fluid/operators/ngraph/ops/scale_op.h
+++ b/paddle/fluid/operators/ngraph/ops/scale_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -38,4 +37,3 @@ void BuildScaleNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif
diff --git a/paddle/fluid/operators/ngraph/ops/top_k_op.h b/paddle/fluid/operators/ngraph/ops/top_k_op.h
index 2b7254497c..ea66953a12 100644
--- a/paddle/fluid/operators/ngraph/ops/top_k_op.h
+++ b/paddle/fluid/operators/ngraph/ops/top_k_op.h
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_NGRAPH
 #pragma once
 
 #include <string>
@@ -48,4 +47,3 @@ void BuildTopKNode(
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
-#endif

From a037378fdb96773f44e0c12c14d2119b7e76996a Mon Sep 17 00:00:00 2001
From: qingqing01 <dangqingqing@baidu.com>
Date: Wed, 9 Jan 2019 10:16:40 +0800
Subject: [PATCH 27/28] Fix error with cuDNN version less than 7.1. (#15219)

Since conv_fusion_op is not exposed to Python, remove the env flag in __init__.py
test=develop
---
 python/paddle/fluid/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index f9f3807b15..2c17716500 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -155,7 +155,7 @@ def __bootstrap__():
             'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
             'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
             'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
-            'cudnn_exhaustive_search_times', 'sync_nccl_allreduce'
+            'sync_nccl_allreduce'
         ]
 
     core.init_gflags([sys.argv[0]] +

From f23a257e905e61f513c2a68cdfd9fb39d8ff16db Mon Sep 17 00:00:00 2001
From: Tao Luo <luotao02@baidu.com>
Date: Wed, 9 Jan 2019 11:26:14 +0800
Subject: [PATCH 28/28] use the new MKLDNN repo url

test=develop
---
 cmake/external/mkldnn.cmake | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index a9b99e9ab8..03f0dee859 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -55,7 +55,7 @@ ExternalProject_Add(
     ${MKLDNN_PROJECT}
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS             ${MKLDNN_DEPENDS}
-    GIT_REPOSITORY      "https://github.com/01org/mkl-dnn.git"
+    GIT_REPOSITORY      "https://github.com/intel/mkl-dnn.git"
     GIT_TAG             "830a10059a018cd2634d94195140cf2d8790a75a"
     PREFIX              ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND      ""