diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index b9491c953f..ad19d729eb 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -174,7 +174,7 @@ else()
   cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
 endif()
 
-target_link_libraries(executor garbage_collector)
+target_link_libraries(executor garbage_collector while_op_helper)
 
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
         threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index dc308fd259..9f06455ea5 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
 cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
+cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
+cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
 cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
 
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h
index 1e3dbb1e44..e98b16e6b3 100644
--- a/paddle/fluid/framework/details/computation_op_handle.h
+++ b/paddle/fluid/framework/details/computation_op_handle.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <memory>
 #include <string>
 #include <vector>
 
@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase {
   ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
                       size_t scope_idx);
 
+  OperatorBase *GetOp() { return op_.get(); }
+
   std::string Name() const override;
 
   const Scope *GetScope() const { return scope_; }
diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc
index 03fbfd7f24..dbc90737f2 100644
--- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc
+++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc
@@ -12,6 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <memory>
+#include <unordered_set>
+#include <utility>
+
 #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/scope.h"
@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
     }
   }
 #endif
+  PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
 }
 
 EagerDeletionOpHandle::~EagerDeletionOpHandle() {
@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
 std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
 
 void EagerDeletionOpHandle::RunImpl() {
-  auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  Scope *exec_scope = nullptr;
   std::deque<std::shared_ptr<memory::Allocation>> garbages;
   for (auto &name : var_names_) {
     auto it = ref_cnts_->find(name);
-    // Var not found, not reference count has not decreased to 0
+    // Var is not tracked, or its reference count has not decreased to 0 yet
     if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
       continue;
     }
 
+    if (!exec_scope) {
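+      // The local exec scope is only needed when some variable is actually
+      // ready to be deleted, so look it up lazily here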
+      exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    }
+
+    // Var not found
     auto *var = exec_scope->FindVar(name);
     if (var == nullptr) {
       continue;
diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc
index 4e42d0b497..377bb915e0 100644
--- a/paddle/fluid/framework/details/eager_deletion_pass.cc
+++ b/paddle/fluid/framework/details/eager_deletion_pass.cc
@@ -12,20 +12,173 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <algorithm>
+#include <functional>
 #include <queue>
 #include <string>
+#include <tuple>
 #include <vector>
 
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
-#include "paddle/fluid/framework/details/eager_deletion_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 
+DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
+              "Fraction of eager deletion. If less than 1.0, all variables in "
+              "the program would be sorted according to its memory size, and "
+              "only the FLAGS_memory_fraction_of_eager_deletion of the largest "
+              "variables would be deleted.");
+
 namespace paddle {
 namespace framework {
 namespace details {
 
+// op -> variables which can be deleted after op runs
+using OpToVarNameSetMap =
+    std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
+
+// Check whether the variable is LoDTensor based on static VarDesc info
+static bool IsLoDTensor(VarDesc *var) {
+  return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
+}
+
+// Get memory size of LoDTensor
+static int64_t GetMemorySize(
+    const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
+    const std::string &var_name) {
+  auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
+  PADDLE_ENFORCE_NOT_NULL(var_desc);
+  PADDLE_ENFORCE(IsLoDTensor(var_desc));
+  auto dims = var_desc->GetShape();
+  return SizeOfType(var_desc->GetDataType()) *
+         std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
+                         std::multiplies<int64_t>());
+}
+
+// Split all variables in the graph into LoDTensor and non-LoDTensor variables
+// (e.g. SelectedRows, LoDTensorArray). Since partial GC is based on a static
+// analysis of the memory size of each variable, only LoDTensors take part in
+// the size-based ranking; SelectedRows and LoDTensorArray are handled
+// separately.
+static void SplitIntoLoDTensorAndNonLoDTensorVars(
+    const OpToVarNameSetMap &m, const GraphVars &vars,
+    OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
+  lod_tensors->clear();
+  other_vars->clear();
+
+  for (auto &op_vars_pair : m) {
+    for (auto &var_name : op_vars_pair.second) {
+      auto *var_desc = TryGetLatestVarDesc(
+          vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
+      if (IsLoDTensor(var_desc)) {
+        (*lod_tensors)[op_vars_pair.first].insert(var_name);
+      } else {
+        (*other_vars)[op_vars_pair.first].insert(var_name);
+      }
+    }
+  }
+}
+
+struct GCVarInfo {
+  GCVarInfo(const std::string &name, int64_t memory_size,
+            ComputationOpHandle *op, size_t scope_idx)
+      : name_(name),
+        memory_size_(memory_size),
+        op_(op),
+        scope_idx_(scope_idx) {}
+
+  std::string name_;         // variable name
+  int64_t memory_size_;      // memory size
+  ComputationOpHandle *op_;  // op after which the variable could be deleted
+  size_t scope_idx_;         // scope index where the variable locates
+
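+  // memory_size_ may be negative when the shape contains an unknown dim
+  // (e.g. -1), so variables are compared by their absolute sizes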
+  int64_t AbsMemorySize() const { return std::abs(memory_size_); }
+};
+
+// NOTE: delete_lod_tensor_only is not used currently
+static OpToVarNameSetMap ShrinkGCVars(
+    const OpToVarNameSetMap &m, const GraphVars &vars,
+    const std::vector<platform::Place> &places, double fraction_of_memory_size,
+    bool delete_lod_tensor_only = false) {
+  // Do not perform gc when fraction_of_memory_size <= 0
+  if (fraction_of_memory_size <= 0.0) return {};
+
+  /**
+   * Step 1: Split all variables into LoDTensor and Non-LoDTensor.
+   * We can only calculate memory size of LoDTensors
+   */
+  OpToVarNameSetMap lod_tensors, other_vars;
+  SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
+
+  // Perform complete gc when fraction_of_memory_size >= 1
+  if (fraction_of_memory_size >= 1.0) {
+    return delete_lod_tensor_only ? lod_tensors : m;
+  }
+
+  /**
+   * Step 2: build GCVarInfos, and calculate total memory sizes of each device
+   */
+
+  // place -> variable info (name, memory size, op, scope_idx)
+  std::map<platform::Place, std::vector<GCVarInfo>> place_to_vars;
+
+  // place -> total memory sizes
+  std::map<platform::Place, int64_t> place_to_size;
+  for (auto &op_vars_pair : lod_tensors) {
+    auto *op = op_vars_pair.first;
+    auto &var_names = op_vars_pair.second;
+    auto scope_idx = op->GetScopeIdx();
+    auto &place = places[scope_idx];
+
+    for (auto &var_name : var_names) {
+      auto var_size = GetMemorySize(vars[scope_idx], var_name);
+      GCVarInfo var_info(var_name, var_size, op, scope_idx);
+      place_to_size[place] += var_info.AbsMemorySize();
+      place_to_vars[place].emplace_back(std::move(var_info));
+    }
+  }
+
+  /**
+   * Step 3: sort GCVarInfos, and only delete the largest variables.
+   */
+  OpToVarNameSetMap partial_vars;
+  for (auto &place_to_var_pair : place_to_vars) {
+    auto &place = place_to_var_pair.first;
+    auto &gc_vars = place_to_var_pair.second;
+    std::sort(gc_vars.begin(), gc_vars.end(),
+              [](const GCVarInfo &var1, const GCVarInfo &var2) {
+                return var1.AbsMemorySize() > var2.AbsMemorySize();
+              });
+
+    int64_t accumulated_size = 0;
+    int64_t size_threshold =
+        static_cast<int64_t>(fraction_of_memory_size * place_to_size[place]);
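+    // Greedily pick the largest variables until the accumulated size
+    // reaches fraction_of_memory_size of the total size on this place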
+    for (size_t i = 0; i < gc_vars.size() && accumulated_size < size_threshold;
+         ++i) {
+      partial_vars[gc_vars[i].op_].insert(gc_vars[i].name_);
+      accumulated_size += gc_vars[i].AbsMemorySize();
+    }
+  }
+
+  /**
+   * Step 4: Combine other vars (SelectedRows, LoDTensorArray)
+   */
+  if (!delete_lod_tensor_only) {
+    for (auto &op_vars_pair : other_vars) {
+      partial_vars[op_vars_pair.first].insert(op_vars_pair.second.begin(),
+                                              op_vars_pair.second.end());
+    }
+  }
+
+  return partial_vars;
+}
+
+class EagerDeletionPass : public ir::Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+};
+
 std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts =
@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
 
   // a reverse map of last_live_ops
   //   i.e., last op --> variable names which can be deleted.
-  std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
-      op_vars_map;
-
+  OpToVarNameSetMap op_vars_map;
   for (auto &var_ops_map : last_live_ops) {
     for (auto &var_ops_pair : var_ops_map) {
       const std::string &var_name = var_ops_pair.first;
@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     }
   }
 
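+  // Shrink the candidate set: when FLAGS_memory_fraction_of_eager_deletion
+  // is less than 1.0, only the largest variables on each place are deleted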
+  op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
+                             FLAGS_memory_fraction_of_eager_deletion);
+
   for (auto &pair : op_vars_map) {
     auto *op = pair.first;
     auto &var_names = pair.second;
@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     eager_deletion_op->AddOutput(dummy_leaf);
   }
 
+  VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
+           << FLAGS_memory_fraction_of_eager_deletion;
   VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
-  return graph;
+
+  auto while_op_eager_deletion_pass =
+      ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
+  return while_op_eager_deletion_pass->Apply(std::move(graph));
 }
 
 }  // namespace details
@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
     .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
     .RequirePassAttr(paddle::framework::details::kAllPlaces)
     .RequirePassAttr(paddle::framework::details::kGarbageCollector);
+
+USE_PASS(while_op_eager_deletion_pass);
diff --git a/paddle/fluid/framework/details/eager_deletion_pass.h b/paddle/fluid/framework/details/eager_deletion_pass.h
deleted file mode 100644
index d7a7a9709d..0000000000
--- a/paddle/fluid/framework/details/eager_deletion_pass.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/pass.h"
-
-namespace paddle {
-namespace framework {
-namespace details {
-
-class EagerDeletionPass : public ir::Pass {
- protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
-};
-
-}  // namespace details
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc
index 13a042d8e6..6092143449 100644
--- a/paddle/fluid/framework/details/reference_count_pass.cc
+++ b/paddle/fluid/framework/details/reference_count_pass.cc
@@ -12,9 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <memory>
 #include <queue>
 #include <string>
 #include <type_traits>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 
 #include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
   return shrink_func(computation_op);
 }
 
-static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
-  VarDesc *var_desc = nullptr;
-  std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
-    var_desc = var_handle->Node()->Var();
-    return var_desc != nullptr;
-  });
-  return var_desc;
-}
-
 std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.cc b/paddle/fluid/framework/details/reference_count_pass_helper.cc
index 89bd08c2d0..94de0e6ab0 100644
--- a/paddle/fluid/framework/details/reference_count_pass_helper.cc
+++ b/paddle/fluid/framework/details/reference_count_pass_helper.cc
@@ -13,9 +13,22 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/framework/var_desc.h"
 
 namespace paddle {
 namespace framework {
-namespace details {}  // namespace details
+namespace details {
+
+VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
+  VarDesc *var_desc = nullptr;
+  std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
+    var_desc = var_handle->Node()->Var();
+    return var_desc != nullptr;
+  });
+  return var_desc;
+}
+
+}  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.h b/paddle/fluid/framework/details/reference_count_pass_helper.h
index 1c083dbf00..ce700119c5 100644
--- a/paddle/fluid/framework/details/reference_count_pass_helper.h
+++ b/paddle/fluid/framework/details/reference_count_pass_helper.h
@@ -16,6 +16,7 @@
 
 #include <atomic>
 #include <map>
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -25,6 +26,10 @@
 
 namespace paddle {
 namespace framework {
+
+class VarDesc;
+class VarHandle;
+
 namespace details {
 
 class ComputationOpHandle;
@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector";
 const char kAllPlaces[] = "all_places";
 
 using LastLiveOpsOfVars =
-    std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
+    std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
 const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
 
+VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc b/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
new file mode 100644
index 0000000000..fd6b6dd227
--- /dev/null
+++ b/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class WhileOpEagerDeletionPass : public ir::Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override {
+    auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
+
+    // Find all while_op and while_grad_op
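+    // scope index -> (while ops, while_grad ops) on that device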
+    std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
+                                         std::vector<OperatorBase *>>>
+        target_ops;
+    for (auto *op : all_ops) {
+      auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
+      if (compute_op == nullptr) continue;
+
+      if (compute_op->Name() == "while") {
+        target_ops[compute_op->GetScopeIdx()].first.emplace_back(
+            compute_op->GetOp());
+      } else if (compute_op->Name() == "while_grad") {
+        target_ops[compute_op->GetScopeIdx()].second.emplace_back(
+            compute_op->GetOp());
+      }
+    }
+
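+    // For each device, let while_op_helper prepare safe eager deletion for
+    // the matched while / while_grad ops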
+    for (auto &ops_pair : target_ops) {
+      auto &while_ops = ops_pair.second.first;
+      auto &while_grad_ops = ops_pair.second.second;
+      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+          while_ops, while_grad_ops);
+    }
+    return graph;
+  }
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(while_op_eager_deletion_pass,
+              paddle::framework::details::WhileOpEagerDeletionPass);
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index c31d0beec3..f3869ceb6d 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -14,6 +14,10 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/executor.h"
 #include <deque>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
 
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
@@ -23,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/framework/transfer_scope_cache.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
 
 ExecutorPrepareContext::ExecutorPrepareContext(
     const framework::ProgramDesc& prog, size_t block_id,
-    const std::vector<std::string>& skip_ref_cnt_vars)
-    : prog_(prog), block_id_(block_id) {
-  if (GetEagerDeletionThreshold() >= 0) {
-    global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
-                                                        skip_ref_cnt_vars);
+    const std::vector<std::string>& keep_vars, bool force_disable_gc)
+    : prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
+  if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) {
+    global_ref_cnts_ =
+        GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars);
   }
 }
 
@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
 }
 
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
-                   bool create_local_scope, bool create_vars) {
+                   bool create_local_scope, bool create_vars,
+                   const std::vector<std::string>& skip_ref_cnt_vars,
+                   bool force_disable_gc) {
   platform::RecordBlock b(block_id);
   if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
 #ifdef PADDLE_WITH_NGRAPH
   if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
 #endif
-  auto ctx = Prepare(pdesc, block_id);
+  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
   RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
 }
 
@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
 
 std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
     const ProgramDesc& program, int block_id,
-    const std::vector<std::string>& skip_ref_cnt_vars) {
-  std::unique_ptr<ExecutorPrepareContext> ctx(
-      new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
+    const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
+  std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext(
+      program, block_id, skip_ref_cnt_vars, force_disable_gc));
   PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
   auto& block = program.Block(block_id);
   for (auto& op_desc : block.AllOps()) {
@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
 
 std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
     const ProgramDesc& program, const std::vector<int>& block_ids,
-    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
+    const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
+    bool force_disable_gc) {
   PADDLE_ENFORCE(
       skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
       "skip_ref_cnt_vars should be either empty or equals to block number %d",
@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
   for (auto& bid : block_ids) {
     ExecutorPrepareContext* ctx;
     if (skip_ref_cnt_vars.empty()) {
-      ctx = new ExecutorPrepareContext(program, bid);
+      ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
+                                       force_disable_gc);
     } else {
-      ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
+      ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
+                                       force_disable_gc);
     }
     PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
     auto& block = program.Block(bid);
@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
 
   int64_t max_memory_size = GetEagerDeletionThreshold();
   std::unique_ptr<GarbageCollector> gc;
-  // skip while_op and while_grad_op temporarily
-  if (max_memory_size >= 0 && !keep_kids) {
+  // FIXME(zjl): recurrent_op is rather complex, so we forcibly disable gc
+  // in recurrent_op for now
+  if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
     ctx->ResetReferenceCount();
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place_)) {
@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
 #ifdef PADDLE_WITH_CUDA
     }
 #endif
+    // If gc is enabled and the program has more than one block, prepare
+    // while_op / while_grad_op for safe eager deletion of sub-block vars
+    if (gc && ctx->prog_.Size() > 1) {
+      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
+                                                                 ctx->ops_);
+    }
   }
 
   for (auto& op : ctx->ops_) {
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 5a040ac641..65cb9e51ab 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -15,7 +15,9 @@ limitations under the License. */
 #pragma once
 
 #include <map>
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/op_info.h"
@@ -30,7 +32,8 @@ namespace framework {
 struct ExecutorPrepareContext {
   ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
                          const std::vector<std::string>& skip_ref_cnt_vars =
-                             std::vector<std::string>());
+                             std::vector<std::string>(),
+                         bool force_disable_gc = false);
 
   ~ExecutorPrepareContext();
 
@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
 
   const framework::ProgramDesc& prog_;
   size_t block_id_;
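+  // When true, gc is disabled for this prepared context even if the eager
+  // deletion threshold is set (e.g. inside recurrent_op)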
+  bool force_disable_gc_;
   std::vector<std::unique_ptr<OperatorBase>> ops_;
 
   std::unordered_map<std::string, size_t> global_ref_cnts_;
@@ -66,7 +70,10 @@ class Executor {
    *  Scope
    */
   void Run(const ProgramDesc& prog, Scope* scope, int block_id,
-           bool create_local_scope = true, bool create_vars = true);
+           bool create_local_scope = true, bool create_vars = true,
+           const std::vector<std::string>& skip_ref_cnt_vars =
+               std::vector<std::string>(),
+           bool force_disable_gc = false);
 
   // This API is very slow.
   void Run(const ProgramDesc& program, Scope* scope,
@@ -79,12 +86,14 @@ class Executor {
   static std::unique_ptr<ExecutorPrepareContext> Prepare(
       const ProgramDesc& program, int block_id,
       const std::vector<std::string>& skip_ref_cnt_vars =
-          std::vector<std::string>());
+          std::vector<std::string>(),
+      bool force_disable_gc = false);
 
   static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
       const ProgramDesc& program, const std::vector<int>& block_ids,
       const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
-          std::vector<std::vector<std::string>>());
+          std::vector<std::vector<std::string>>(),
+      bool force_disable_gc = false);
 
   void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
 
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 012dfc1c7f..5530823b90 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -159,10 +159,9 @@ class Autograd {
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
-          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " trace id "
+          VLOG(5) << "op dep " << candidate->Type() << " trace id "
                   << candidate->trace_id_ << " <---- " << it.first << " <---- "
-                  << pre_op->op_desc_->Type() << " trace id "
-                  << pre_op->trace_id_;
+                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
           if (visited.find(pre_op) == visited.end()) {
             visited.insert(pre_op);
             queue.push_back(pre_op);
@@ -180,10 +179,12 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
   PADDLE_ENFORCE(var_->IsInitialized(),
                  "Variable must be initialized when getting numpy tensor");
 
-  std::unique_ptr<VarBase> new_var(new VarBase());
+  // TODO(minqiyang): change this after moving the unique_name generator to C++
+  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
+  std::unique_ptr<VarBase> new_var(new VarBase(
+      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
   framework::LoDTensor* tensor =
       new_var->var_->GetMutable<framework::LoDTensor>();
-  tensor->Resize(var_->Get<framework::LoDTensor>().dims());
   tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
 
   if (blocking) {
@@ -199,52 +200,62 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
   }
 
   if (platform::is_gpu_place(dst_place)) {
-    VLOG(3) << "copy tensor " << var_desc_->Name() << " from gpu";
+    VLOG(3) << "copy tensor " << Name() << " from gpu";
   }
 
   return new_var;
 }
 
 framework::LoDTensor& VarBase::GradValue() {
-  VLOG(3) << "get var grad " << var_desc_->Name();
+  VLOG(3) << "get var grad " << Name();
+  PADDLE_ENFORCE_NOT_NULL(grads_,
+                          "Could not get grad value: variable has no grad");
   return *(grads_->var_->GetMutable<framework::LoDTensor>());
 }
 
 std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   if (grad_op_descs_.empty() && backward_id_ <= 0) {
-    VLOG(3) << "op with no grad: " << op_desc_->Type();
+    VLOG(3) << "op with no grad: " << Type();
     return {};
   }
 
-  VLOG(3) << "apply op grad: " << op_desc_->Type();
-  std::vector<framework::VariableValueMap> grad_outputs;
+  VLOG(3) << "apply op grad: " << Type();
+  std::vector<framework::VariableValueMap> tmp_grad_outputs;
   if (backward_id_ > 0) {
     VLOG(3) << "py_layer_grad";
-    grad_outputs.resize(1);
-    grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
+    tmp_grad_outputs.resize(1);
+    tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
         PyLayer::ApplyGrad(
             backward_id_,
             grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
   } else {
-    grad_outputs.resize(grad_op_descs_.size());
-    for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
+    const size_t grad_op_count = grad_op_descs_.size();
+
+    tmp_grad_outputs.resize(grad_op_count);
+    for (size_t k = 0; k < grad_op_count; ++k) {
       framework::OpDesc* grad_op_desc = grad_op_descs_[k];
-      VLOG(3) << "op grad " << grad_op_desc->Type();
-      for (auto it : grad_output_vars_[k]) {
-        auto& outputs = grad_outputs[k][it.first];
+      auto& grad_output_variable_map = grad_output_vars_[k];
+
+      VLOG(3) << "apply grad op " << grad_op_desc->Type();
+
+      // Allocate tmp grad output variable
+      for (auto it : grad_output_variable_map) {
+        auto& outputs = tmp_grad_outputs[k][it.first];
+        outputs.reserve(it.second.size());
         for (size_t i = 0; i < it.second.size(); ++i) {
           // Allocate a new variable
           Variable* tmp_var = new framework::Variable();
           tmp_var->GetMutable<framework::LoDTensor>();
-          outputs.push_back(tmp_var);
+          outputs.emplace_back(tmp_var);
         }
       }
 
-      framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
+      // Run grad op
+      framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
 
       // No need to do compile time infer shape here.
       // grad_op_desc_->InferShape(*block_);
-      grad_op_desc->InferVarType(block_);
+      // grad_op_desc->InferVarType(block_);
 
       std::unique_ptr<framework::OperatorBase> opbase =
           framework::OpRegistry::CreateOp(*grad_op_desc);
@@ -260,9 +271,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
     }
   }
 
+  // Add tmp grad outputs to original grad vars
   for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
     for (auto it : grad_output_vars_[k]) {
-      auto& outputs = grad_outputs[k][it.first];
+      auto& outputs = tmp_grad_outputs[k][it.first];
       auto& origin_outputs = it.second;
       PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
 
@@ -316,19 +328,14 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
 
 int PyLayer::NumFuncs() { return py_funcs_.size(); }
 
-std::vector<VarBase*> PyLayer::Apply(int func_id,
-                                     const std::vector<VarBase*>& inputs) {
+std::vector<Variable*> PyLayer::Apply(int func_id,
+                                      const std::vector<VarBase*>& inputs) {
   std::vector<framework::Variable*> invars;
   for (const VarBase* in : inputs) {
     invars.push_back(in->var_);
   }
   PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars);
-  std::vector<VarBase*> ret;
-  for (Variable* v : outvars) {
-    ret.push_back(new VarBase(v, new VarBase(true)));
-  }
-  return ret;
+  return CallPythonFunc(py_funcs_[func_id], invars);
 }
 
 std::vector<Variable*> PyLayer::ApplyGrad(
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 7a9f33dc1e..618a5b7a03 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -112,31 +112,53 @@ class OpBase;
  */
 class VarBase {
  public:
-  VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}
-
-  explicit VarBase(bool stop_gradient)
-      : VarBase(new framework::Variable(),
-                stop_gradient ? nullptr : new VarBase(true), stop_gradient) {}
-
-  VarBase(framework::Variable* var, VarBase* grad)
-      : VarBase(var, grad, false) {}
+  // Internal interface, create VarBase from an existing variable
+  VarBase(const std::string& name, framework::Variable* var, VarBase* grad,
+          bool stop_gradient)
+      : VarBase(name, var->Get<framework::LoDTensor>().type(),
+                var->Get<framework::LoDTensor>().dims(),
+                var->Get<framework::LoDTensor>().place(), var, grad,
+                stop_gradient, false) {}
+
+  // Python interface
+  VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
+          const std::vector<int64_t>& shape, const platform::Place& place,
+          bool stop_gradient, bool persistable)
+      : VarBase(name, dtype, framework::make_ddim(shape), place, stop_gradient,
+                persistable) {}
+
+  // Internal interface, create VarBase with a ddim
+  VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
+          const framework::DDim& shape, const platform::Place& place,
+          bool stop_gradient, bool persistable)
+      : VarBase(name, dtype, shape, place, nullptr, nullptr, stop_gradient,
+                persistable) {}
 
  private:
-  VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
-      : name_(),
-        var_desc_(nullptr),
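+  // All other constructors delegate here; when `var` is nullptr, a new
+  // Variable holding a LoDTensor of the given shape/dtype is allocated on
+  // `place`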
+  VarBase(const std::string& name, framework::proto::VarType::Type dtype,
+          const framework::DDim& shape, const platform::Place& place,
+          framework::Variable* var, VarBase* grad, bool stop_gradient,
+          bool persistable)
+      : name_(name),
+        dtype_(dtype),
+        place_(place),
         var_(var),
         grads_(grad),
-        block_(nullptr),
-        persistable_(false),
         stop_gradient_(stop_gradient),
+        persistable_(persistable),
         pre_op_(nullptr),
         pre_op_out_name_(),
-        pre_op_out_idx_(-1) {}
+        pre_op_out_idx_(-1) {
+    if (!var_) {
+      var_ = new framework::Variable();
+      auto tensor = var_->GetMutable<framework::LoDTensor>();
+      tensor->Resize(shape);
+      tensor->mutable_data(place_, dtype_);
+    }
+  }
 
  public:
   virtual ~VarBase() {
-    // TODO(minqiyang): remove var desc from block desc
     if (var_) {
       delete var_;
       var_ = nullptr;
@@ -151,14 +173,30 @@ class VarBase {
     pre_op_out_idx_ = -1;
   }
 
-  inline OpBase* PreOp() const { return pre_op_; }
-  inline int PreOpOutIdx() const { return pre_op_out_idx_; }
+  inline void SetName(const std::string& name) { name_ = name; }
+  inline std::string Name() const { return name_; }
+
+  inline std::vector<int64_t> Shape() const {
+    if (var_->IsInitialized()) {
+      return framework::vectorize(var_->Get<framework::LoDTensor>().dims());
+    } else {
+      return {};
+    }
+  }
+
+  inline framework::proto::VarType::Type DType() const { return dtype_; }
 
   inline void SetStopGradient(bool stop_gradient) {
     stop_gradient_ = stop_gradient;
   }
   inline bool IsStopGradient() const { return stop_gradient_; }
 
+  inline void SetPersistable(bool persistable) { persistable_ = persistable; }
+  inline bool IsPersistable() const { return persistable_; }
+
+  inline OpBase* PreOp() const { return pre_op_; }
+  inline int PreOpOutIdx() const { return pre_op_out_idx_; }
+
   void RunBackward();
 
   inline void ResetPreOp(OpBase* op) {
@@ -180,7 +218,7 @@ class VarBase {
   }
 
   void ClearGradient() {
-    VLOG(1) << "clear gradient of " << var_desc_->Name();
+    VLOG(1) << "clear gradient of " << Name();
     if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
       auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
       operators::math::set_constant(
@@ -196,23 +234,20 @@ class VarBase {
                                       const bool blocking) const;
 
   inline std::string GradName() const {
-    PADDLE_ENFORCE(
-        var_desc_,
-        "Couldn't get gradient variable's name, please call backward() first");
-    return string::Sprintf("%s@IGrad", var_desc_->Name());
+    return string::Sprintf("%s@IGrad", Name());
   }
 
   std::string name_;
-  framework::VarDesc* var_desc_;
+  framework::proto::VarType::Type dtype_;
+  platform::Place place_;
 
   framework::Variable* var_;
   VarBase* grads_;
 
-  framework::BlockDesc* block_;
-  bool persistable_;
-
  private:
   bool stop_gradient_;
+  bool persistable_;
+
   OpBase* pre_op_;
   std::string pre_op_out_name_;
   int pre_op_out_idx_;
@@ -223,11 +258,11 @@ class VarBase {
  */
 class PYBIND11_HIDDEN OpBase {
  public:
-  OpBase()
-      : op_desc_(nullptr),
+  OpBase(const std::string& type)
+      : type_(type),
+        trace_id_(-1),
         forward_id_(-1),
         backward_id_(-1),
-        trace_id_(-1),
         place_(platform::CPUPlace()),
         backward_hooks_() {}
 
@@ -249,13 +284,34 @@ class PYBIND11_HIDDEN OpBase {
 
   std::map<std::string, std::vector<VarBase*>> ApplyGrad();
 
+  inline std::string Type() const { return type_; }
+  inline std::string GradOpType(size_t index) const {
+    PADDLE_ENFORCE_NOT_NULL(grad_op_descs_[index]);
+    return grad_op_descs_[index]->Type();
+  }
+
   void RegisterBackwardHooks(const py::object& callable);
 
   void InvokeBackwardHooks();
 
-  // One of `op_desc_` or `forward_id_` is set, not both.
-  // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
-  framework::OpDesc* op_desc_;
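+  // Record inp_var's producer op (if any) as a predecessor of this op in the
+  // given input slot, so that Autograd can walk op dependencies backwards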
+  void TrackPreOp(const VarBase* inp_var, const std::string& inp_name) {
+    if (inp_var->PreOp() && !inp_var->IsStopGradient()) {
+      VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot "
+              << inp_name;
+      pre_ops_[inp_name].push_back(inp_var->PreOp());
+      pre_ops_out_idx_[inp_name].push_back(inp_var->PreOpOutIdx());
+    } else {
+      VLOG(3) << "no pre op in slot " << inp_name
+              << " input var stop_gradient: " << inp_var->IsStopGradient();
+      pre_ops_[inp_name].push_back(nullptr);
+      // pre_ops_out_idx_[inp_name].push_back(-1);
+    }
+  }
+
+  std::string type_;
+  // One of `trace_id_` or `forward_id_` is set, not both.
+  // For pure python PyLayer, use `forward_id_`, otherwise, use trace_id_.
+  int trace_id_;
   int forward_id_;
 
   // When has backward, one of `grad_op_descs_` or `backward_id_` is set,
@@ -263,7 +319,6 @@ class PYBIND11_HIDDEN OpBase {
   // Note: each fwd op corresponds to a vector of bwd ops.
   std::vector<framework::OpDesc*> grad_op_descs_;
   int backward_id_;
-  int trace_id_;
 
   platform::Place place_;
 
@@ -277,8 +332,6 @@ class PYBIND11_HIDDEN OpBase {
   // Outputs to a vector of bwd ops.
   std::vector<framework::VariableValueMap> grad_output_vars_;
 
-  framework::BlockDesc* block_;
-
   std::vector<py::object> backward_hooks_;
 };
 
@@ -303,8 +356,8 @@ class PyLayer {
 
   static int NumFuncs();
 
-  static std::vector<VarBase*> Apply(int func_id,
-                                     const std::vector<VarBase*>& inputs);
+  static std::vector<framework::Variable*> Apply(
+      int func_id, const std::vector<VarBase*>& inputs);
 
   static std::vector<framework::Variable*> ApplyGrad(
       int func_id, const std::vector<framework::Variable*>& inputs);
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 0cb1676372..7ee92b4d8c 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -56,15 +56,19 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   }
 }
 
-void InitVar(framework::Variable* var, framework::Variable* grad_var,
-             platform::DeviceContext* dev_ctx) {
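+// Lazily create `var`'s gradient VarBase (an FP32 tensor of the same shape,
+// zero-initialized on the given device) if it does not exist yet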
+void InitGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
+  PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base");
   PADDLE_ENFORCE_NOT_NULL(dev_ctx,
                           "Could not get valid device from forward op");
-  auto& var_t = var->Get<framework::LoDTensor>();
-  grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
-      var_t.dims(), dev_ctx->GetPlace());
-  operators::math::set_constant(
-      *dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), 0.0);
+
+  if (var->grads_ == nullptr) {
+    auto& var_t = var->var_->Get<framework::LoDTensor>();
+    var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32,
+                              framework::vectorize(var_t.dims()),
+                              dev_ctx->GetPlace(), true, false);
+    auto grad_t = var->grads_->var_->GetMutable<framework::LoDTensor>();
+    operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+  }
 }
 
 platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
@@ -85,6 +89,62 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
   return result;
 }
 
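+// Build the input framework::VariableNameMap (op slot name -> VarBase names)
+// needed to create the operator and its grad ops without a pre-built OpDesc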
+framework::VariableNameMap CreateInputVarNameMap(
+    const OpBase* op, const VarBasePtrMap& varbase_map) {
+  framework::VariableNameMap result;
+
+  auto& info_map = framework::OpInfoMap::Instance();
+  auto* op_info = info_map.GetNullable(op->Type());
+  if (op_info == nullptr || op_info->proto_ == nullptr) {
+    return result;
+  }
+
+  for (auto& in : op_info->Proto().inputs()) {
+    auto it = varbase_map.find(in.name());
+    if (it == varbase_map.end()) {
+      PADDLE_ENFORCE(in.dispensable());
+      result[in.name()] = {};
+    } else {
+      auto var_vector = it->second;
+      std::vector<std::string> args;
+      args.reserve(var_vector.size());
+      for (VarBase* var_base : var_vector) {
+        args.emplace_back(var_base->Name());
+      }
+      result[in.name()] = args;
+    }
+  }
+  return result;
+}
+
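+// Build the output framework::VariableNameMap, analogous to the input map
+// above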
+framework::VariableNameMap CreateOutputVarNameMap(
+    const OpBase* op, const VarBasePtrMap& varbase_map) {
+  framework::VariableNameMap result;
+
+  auto& info_map = framework::OpInfoMap::Instance();
+  auto* op_info = info_map.GetNullable(op->Type());
+  if (op_info == nullptr || op_info->proto_ == nullptr) {
+    return result;
+  }
+
+  for (auto& out : op_info->Proto().outputs()) {
+    auto it = varbase_map.find(out.name());
+    if (it == varbase_map.end()) {
+      PADDLE_ENFORCE(out.dispensable());
+      result[out.name()] = {};
+    } else {
+      auto var_vector = it->second;
+      std::vector<std::string> args;
+      args.reserve(var_vector.size());
+      for (VarBase* var_base : var_vector) {
+        args.emplace_back(var_base->Name());
+      }
+      result[out.name()] = args;
+    }
+  }
+  return result;
+}
+
 Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
   if (!FLAGS_tracer_profile_fname.empty()) {
     std::call_once(gTracerProfileOnce, [] {
@@ -101,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
 
 std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                                     const VarBasePtrMap& outputs,
-                                    framework::BlockDesc* block,
+                                    framework::AttributeMap attrs_map,
                                     const platform::Place expected_place,
                                     const bool stop_gradient) {
 #ifdef WITH_GPERFTOOLS
@@ -110,40 +170,27 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   }
 #endif
 
-  std::map<std::string, VarBase*> vars;
-
-  framework::OpDesc* op_desc = op->op_desc_;
-  VLOG(3) << "tracer tracing " << op_desc->Type() << " trace id "
-          << op->trace_id_;
-  op_desc->InferShape(*block);
-  op_desc->InferVarType(block);
-
-  std::unique_ptr<framework::OperatorBase> op_base =
-      framework::OpRegistry::CreateOp(*op_desc);
-
   framework::VariableValueMap invars_map;
   framework::VariableValueMap outvars_map;
 
+  // Construct input_vars_map and output_vars_map
+  std::map<std::string, VarBase*> current_vars_map;
   op->input_vars_ = inputs;
   for (auto it : op->input_vars_) {
     auto& invars = invars_map[it.first];
     invars.reserve(it.second.size());
     for (VarBase* inp : it.second) {
-      PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
-                              op->op_desc_->Type(), inp->var_desc_->Name());
+      PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(),
+                              inp->Name());
 
       invars.emplace_back(inp->var_);
-      vars[inp->var_desc_->Name()] = inp;
-      if (inp->PreOp() && !inp->IsStopGradient()) {
-        op->pre_ops_[it.first].push_back(inp->PreOp());
-        op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
-        VLOG(3) << "add pre op " << inp->PreOp()->op_desc_->Type();
-      } else {
-        op->pre_ops_[it.first].push_back(nullptr);
+      op->TrackPreOp(inp, it.first);
+      if (!stop_gradient) {
+        current_vars_map[inp->Name()] = inp;
       }
-      VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
-              << inp->var_->IsInitialized() << " stop_gradient "
-              << inp->IsStopGradient();
+      VLOG(3) << "input var name: " << inp->Name()
+              << " inited: " << inp->var_->IsInitialized()
+              << " stop_grad: " << inp->IsStopGradient();
     }
   }
 
@@ -152,25 +199,38 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     auto& outvars = outvars_map[it.first];
     const std::vector<VarBase*>& outputs = it.second;
     outvars.reserve(outputs.size());
-    for (size_t i = 0; i < outputs.size(); ++i) {
+    for (size_t i = 0U; i < outputs.size(); ++i) {
       VarBase* out = outputs[i];
       outvars.emplace_back(out->var_);
-      vars[out->var_desc_->Name()] = out;
-
-      framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
-      if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-        out->var_->GetMutable<framework::LoDTensor>();
-      } else {
-        LOG(ERROR) << "tracer doesn't support yet";
-      }
       out->TrackPreOp(op, it.first, i, stop_gradient);
+      if (!stop_gradient) {
+        current_vars_map[out->Name()] = out;
+      }
 
-      VLOG(3) << "output vname " << out->var_desc_->Name() << " "
-              << out->var_->IsInitialized();
+      VLOG(3) << "input var name: " << out->Name()
+              << " inited: " << out->var_->IsInitialized()
+              << " stop_grad: " << out->IsStopGradient();
     }
   }
 
-  VLOG(3) << "tracer running " << op_desc->Type();
+  // Check attrs and create op
+  framework::VariableNameMap invars_name_map =
+      CreateInputVarNameMap(op, inputs);
+  framework::VariableNameMap outvars_name_map =
+      CreateOutputVarNameMap(op, outputs);
+
+  auto& info = framework::OpInfoMap::Instance().Get(op->Type());
+  if (info.Checker() != nullptr) {
+    info.Checker()->Check(&attrs_map);
+  }
+
+  std::unique_ptr<framework::OperatorBase> op_base =
+      framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
+                                      outvars_name_map, attrs_map);
+
+  // TODO(minqiyang): Support infer var type in imperative mode
+  // Run forward op
+  VLOG(3) << "tracer running " << op->Type();
   framework::RuntimeContext ctx(invars_map, outvars_map);
 
   // TODO(panyx0718): Cache p.
@@ -186,36 +246,44 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
                                   prepared_op.ctx, prepared_op.kernel_configs));
 
+  // construct backward op
   std::set<std::string> vars_saved_for_backward;
-
   if (!stop_gradient) {
+    VLOG(5) << "start construct backward op";
+
+    // construct grad op descs
+    std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
+        op->Type(), invars_name_map, outvars_name_map, attrs_map));
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
         new std::unordered_map<std::string, std::string>());
-    CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
+    // NOTE(minqiyang): We don't support control flow ops in imperative mode
+    // now. Add grad_block_ when we want to support them.
+    CreateGradOp(*fwd_op_desc, {}, {}, &op->grad_op_descs_, grad_to_var.get());
 
-    op->grad_input_vars_.resize(op->grad_op_descs_.size());
-    op->grad_output_vars_.resize(op->grad_op_descs_.size());
+    VLOG(5) << "create grad op desc: " << op->grad_op_descs_[0]->Type();
 
-    for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
+    const size_t grad_op_count = op->grad_op_descs_.size();
+
+    op->grad_input_vars_.resize(grad_op_count);
+    op->grad_output_vars_.resize(grad_op_count);
+
+    for (size_t i = 0; i < grad_op_count; ++i) {
       framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
       for (auto it : grad_op_desc->Inputs()) {
         auto& grad_in_vars = op->grad_input_vars_[i][it.first];
+        grad_in_vars.reserve(it.second.size());
         for (const std::string& grad_invar : it.second) {
-          block->FindRecursiveOrCreateVar(grad_invar);
           auto var_it = grad_to_var->find(grad_invar);
           if (var_it == grad_to_var->end()) {
-            auto fwd_var_it = vars.find(grad_invar);
-            PADDLE_ENFORCE(fwd_var_it != vars.end());
+            auto fwd_var_it = current_vars_map.find(grad_invar);
+            PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
             // Forward inputs or outputs.
-            grad_in_vars.push_back(fwd_var_it->second->var_);
+            grad_in_vars.emplace_back(fwd_var_it->second->var_);
           } else {
-            VarBase* var = vars[var_it->second];
-            if (!var->grads_->var_->IsInitialized()) {
-              InitVar(var->var_, var->grads_->var_,
-                      prepared_op.GetDeviceContext());
-            }
+            VarBase* var = current_vars_map[var_it->second];
+            InitGrad(var, prepared_op.GetDeviceContext());
             // Douts.
-            grad_in_vars.push_back(var->grads_->var_);
+            grad_in_vars.emplace_back(var->grads_->var_);
           }
 
           vars_saved_for_backward.insert(it.first);
@@ -225,48 +293,48 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       for (auto it : grad_op_desc->Outputs()) {
         auto& grad_out_vars = op->grad_output_vars_[i][it.first];
         for (const std::string& grad_outvar : it.second) {
-          block->FindRecursiveOrCreateVar(grad_outvar);
           auto var_it = grad_to_var->find(grad_outvar);
           PADDLE_ENFORCE(var_it != grad_to_var->end(),
                          "Could not found the grad op output var, should this "
                          "operator %s's stop gradient be True",
-                         op_desc->Type());
-          VarBase* var = vars[var_it->second];
-          if (!var->grads_->var_->IsInitialized()) {
-            InitVar(var->var_, var->grads_->var_,
-                    prepared_op.GetDeviceContext());
-          }
+                         op->Type());
+          VarBase* var = current_vars_map[var_it->second];
+          InitGrad(var, prepared_op.GetDeviceContext());
           grad_out_vars.push_back(var->grads_->var_);
         }
       }
     }
   }
 
-  op->block_ = block;
   return vars_saved_for_backward;
 }
 
 std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
                                       const std::vector<VarBase*>& inputs,
                                       bool stop_gradient) {
-  VLOG(3) << "py_trace";
+  VLOG(3) << "py_trace " << op->Type();
+
   op->input_vars_[PyLayer::kFwdInp] = inputs;
-  op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs);
+
+  std::vector<framework::Variable*> ret_vars =
+      PyLayer::Apply(op->forward_id_, inputs);
+
   for (VarBase* inp : inputs) {
-    if (inp->PreOp() && !inp->IsStopGradient()) {
-      op->pre_ops_[PyLayer::kFwdInp].push_back(inp->PreOp());
-      op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->PreOpOutIdx());
-    } else {
-      op->pre_ops_[PyLayer::kFwdInp].push_back(nullptr);
-    }
+    op->TrackPreOp(inp, PyLayer::kFwdInp);
   }
 
-  auto& outputs = op->output_vars_[PyLayer::kFwdOut];
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    VarBase* out = outputs[i];
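+  // Wrap each framework::Variable returned by the python forward function
+  // into a new VarBase owned by this op's output slot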
+  std::vector<VarBase*>& outputs = op->output_vars_[PyLayer::kFwdOut];
+  outputs.reserve(ret_vars.size());
+  for (size_t i = 0U; i != ret_vars.size(); ++i) {
+    framework::Variable* v = ret_vars[i];
+    VarBase* out = new VarBase(string::Sprintf("%s_out_%d", op->Type(), i), v,
+                               nullptr, stop_gradient);
+    outputs.emplace_back(out);
     out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
   }
+
   if (!stop_gradient) {
+    VLOG(5) << "start construct backward op";
     op->grad_input_vars_.resize(1);
     op->grad_output_vars_.resize(1);
     auto& grad_input_vars =
@@ -281,23 +349,16 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
       grad_input_vars.push_back(out->var_);
     }
 
+    // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
     platform::CPUPlace place;
     for (VarBase* out : outputs) {
+      InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
       grad_input_vars.push_back(out->grads_->var_);
-      if (!grad_input_vars.back()->IsInitialized()) {
-        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
-        InitVar(out->var_, grad_input_vars.back(),
-                platform::DeviceContextPool::Instance().Get(place));
-      }
     }
 
-    for (const VarBase* inp : inputs) {
+    for (VarBase* inp : inputs) {
+      InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
       grad_output_vars.push_back(inp->grads_->var_);
-      if (!grad_output_vars.back()->IsInitialized()) {
-        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
-        InitVar(inp->var_, grad_output_vars.back(),
-                platform::DeviceContextPool::Instance().Get(place));
-      }
     }
   }
   return outputs;
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 8a0267c37f..7b65d55e9e 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -17,6 +17,8 @@
 #include <map>
 #include <set>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "paddle/fluid/framework/op_desc.h"
@@ -34,7 +36,8 @@ void CreateGradOp(const framework::OpDesc& op_desc,
                   framework::OpDesc** grad_op_desc,
                   std::unordered_map<std::string, std::string>* grad_to_var);
 
-void InitVar(framework::Variable* var, framework::Variable* grad_var);
+void InitVar(const VarBase* var, framework::Variable* grad_var,
+             platform::DeviceContext* dev_ctx);
 
 platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
 
@@ -46,7 +49,7 @@ class Tracer {
 
   std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
                               const VarBasePtrMap& outputs,
-                              framework::BlockDesc* block,
+                              framework::AttributeMap attrs_map,
                               const platform::Place expected_place,
                               const bool stop_gradient = false);
 
diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
index cf02901d96..9a40cf4b60 100644
--- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc
+++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -126,15 +126,20 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
 }
 template void ZeroCopyTensor::copy_from_cpu<float>(const float *data);
 template void ZeroCopyTensor::copy_from_cpu<int64_t>(const int64_t *data);
+template void ZeroCopyTensor::copy_from_cpu<int32_t>(const int32_t *data);
 template void ZeroCopyTensor::copy_to_cpu<float>(float *data);
 template void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
+template void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
 
 template float *ZeroCopyTensor::data<float>(PaddlePlace *place,
                                             int *size) const;
 template int64_t *ZeroCopyTensor::data<int64_t>(PaddlePlace *place,
                                                 int *size) const;
+template int32_t *ZeroCopyTensor::data<int32_t>(PaddlePlace *place,
+                                                int *size) const;
 template float *ZeroCopyTensor::mutable_data<float>(PaddlePlace place);
 template int64_t *ZeroCopyTensor::mutable_data<int64_t>(PaddlePlace place);
+template int32_t *ZeroCopyTensor::mutable_data<int32_t>(PaddlePlace place);
 
 void *ZeroCopyTensor::FindTensor() const {
   PADDLE_ENFORCE(!name_.empty(),
diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h
index 1ce3fe5af7..258a79fa4e 100644
--- a/paddle/fluid/inference/api/helper.h
+++ b/paddle/fluid/inference/api/helper.h
@@ -139,9 +139,8 @@ static void TensorAssignData(PaddleTensor *tensor,
 }
 
 template <typename T>
-static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
-                                    const std::vector<std::vector<T>> &data) {
-  int size{0};
+static void ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
+                                     const std::vector<std::vector<T>> &data) {
   auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU);
   int c = 0;
   for (const auto &f : data) {
@@ -149,7 +148,15 @@ static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
       ptr[c++] = v;
     }
   }
-  return size;
+}
+
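+// Assign the raw data held in a PaddleBuf to a ZeroCopyTensor, interpreting
+// the buffer as a flat array of T.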
+template <typename T>
+static void ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
+                                     const PaddleBuf &data) {
+  auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU);
+  for (size_t i = 0; i < data.length() / sizeof(T); i++) {
+    ptr[i] = *(reinterpret_cast<T *>(data.data()) + i);
+  }
 }
 
 static bool CompareTensor(const PaddleTensor &a, const PaddleTensor &b) {
diff --git a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc
index 3f6c933f2b..5157bd280d 100644
--- a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc
@@ -107,6 +107,9 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->DisableGpu();
   cfg->SwitchSpecifyInputNames();
   cfg->SwitchIrOptim();
+  if (FLAGS_zero_copy) {
+    cfg->SwitchUseFeedFetchOps(false);
+  }
 }
 
 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
@@ -131,7 +134,7 @@ TEST(Analyzer_Pyramid_DNN, profile) {
   TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
                  input_slots_all, &outputs, FLAGS_num_threads);
 
-  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
+  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) {
     PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
     size_t size = GetSize(outputs[0]);
     PADDLE_ENFORCE_GT(size, 0);
@@ -166,6 +169,19 @@ TEST(Analyzer_Pyramid_DNN, compare) {
       reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
+// Compare the results of AnalysisConfig and AnalysisConfig + ZeroCopy
+TEST(Analyzer_Pyramid_DNN, compare_zero_copy) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  std::vector<std::string> outputs_name;
+  outputs_name.emplace_back("cos_sim_2.tmp_0");
+  CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
+                             input_slots_all, outputs_name);
+}
+
 // Compare Deterministic result
 TEST(Analyzer_Pyramid_DNN, compare_determine) {
   AnalysisConfig cfg;
diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
index 36282b3efe..dcf4b38ce8 100644
--- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
@@ -207,6 +207,9 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->DisableGpu();
   cfg->SwitchSpecifyInputNames();
   cfg->SwitchIrOptim();
+  if (FLAGS_zero_copy) {
+    cfg->SwitchUseFeedFetchOps(false);
+  }
 }
 
 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
@@ -285,133 +288,17 @@ TEST(Analyzer_rnn1, multi_thread) {
                  input_slots_all, &outputs, 2 /* multi_thread */);
 }
 
-// Validate that the AnalysisPredictor + ZeroCopyTensor really works by testing
-// on the complex RNN1 model.
-TEST(Analyzer_rnn1, ZeroCopy) {
-  AnalysisConfig config;
-  SetConfig(&config);
-  config.SwitchUseFeedFetchOps(false);
-
-  PaddlePlace place;
-
-  auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
-
-  config.SwitchUseFeedFetchOps(true);
-  auto native_predictor =
-      CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
-
-  config.SwitchUseFeedFetchOps(
-      true);  // the analysis predictor needs feed/fetch.
-  auto analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config);
-
-#define NEW_TENSOR(name__) \
-  auto name__##_tensor = predictor->GetInputTensor(#name__);
-  NEW_TENSOR(data_lod_attention);
-  NEW_TENSOR(cell_init);
-  NEW_TENSOR(data);
-  NEW_TENSOR(week);
-  NEW_TENSOR(minute);
-  NEW_TENSOR(hidden_init);
-
-  // Prepare data for AnalysisPredictor
-  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
-  PrepareZeroCopyInputs(data_lod_attention_tensor.get(), cell_init_tensor.get(),
-                        data_tensor.get(), hidden_init_tensor.get(),
-                        week_tensor.get(), minute_tensor.get(), &data,
-                        FLAGS_batch_size);
-
-  // Prepare data for NativePredictor
-  std::vector<std::vector<PaddleTensor>> native_inputs;
-  SetInput(&native_inputs);
-  std::vector<PaddleTensor> native_outputs;
-  std::vector<PaddleTensor> analysis_outputs;
-
-  auto output_tensor = predictor->GetOutputTensor("final_output.tmp_1");
-  // Run analysis predictor
-
-  int num_ops;
-  auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
-  ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-  ASSERT_EQ(fuse_statis.at("fc_fuse"), 1);
-  ASSERT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2);  // bi-directional LSTM
-  ASSERT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
-  ASSERT_EQ(num_ops,
-            13);  // After graph optimization, only 13 operators exists.
-
-  Timer timer;
-  double total_time{0};
-  for (int i = 0; i < FLAGS_repeat; i++) {
-    timer.tic();
-    predictor->ZeroCopyRun();
-    total_time += timer.toc();
-  }
-  LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor);
-
-  ASSERT_TRUE(native_predictor->Run(native_inputs.front(), &native_outputs));
-  LOG(INFO) << "native output " << DescribeTensor(native_outputs.front());
-
-  int output_size{0};  // this is the number of elements not memory size
-  auto *zero_copy_data = output_tensor->data<float>(&place, &output_size);
-  auto *native_data = static_cast<float *>(native_outputs.front().data.data());
-  for (int i = 0; i < output_size; i++) {
-    EXPECT_NEAR(zero_copy_data[i], native_data[i], 1e-3);
-  }
-}
-
-TEST(Analyzer_rnn1, ZeroCopyMultiThread) {
-  AnalysisConfig config;
-  SetConfig(&config);
-  config.SwitchUseFeedFetchOps(false);
-
-#define NEW_TENSOR(name__) \
-  auto name__##_tensor = predictor->GetInputTensor(#name__);
-
-  std::vector<std::unique_ptr<PaddlePredictor>> predictors;
-  predictors.emplace_back(CreatePaddlePredictor<AnalysisConfig>(config));
-  for (int tid = 1; tid < FLAGS_num_threads; tid++) {
-    predictors.emplace_back(predictors.front()->Clone());
-  }
-  double total_time_of_threads{0};
-  std::vector<std::thread> threads;
-
-  for (int tid = 0; tid < FLAGS_num_threads; tid++) {
-    threads.emplace_back([&, tid] {
-      auto &predictor = predictors[tid];
-      NEW_TENSOR(data_lod_attention);
-      NEW_TENSOR(cell_init);
-      NEW_TENSOR(data);
-      NEW_TENSOR(week);
-      NEW_TENSOR(minute);
-      NEW_TENSOR(hidden_init);
-
-      // Prepare data for AnalysisPredictor
-      DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
-      Timer timer;
-      double total_time{0};
-
-      for (int i = 0; i < FLAGS_repeat; i++) {
-        PrepareZeroCopyInputs(data_lod_attention_tensor.get(),
-                              cell_init_tensor.get(), data_tensor.get(),
-                              hidden_init_tensor.get(), week_tensor.get(),
-                              minute_tensor.get(), &data, FLAGS_batch_size);
-
-        timer.tic();
-        predictor->ZeroCopyRun();
-        total_time += timer.toc();
-      }
-
-      total_time_of_threads += total_time;
-
-      LOG(INFO) << "thread time: " << total_time / FLAGS_repeat;
-    });
-  }
-
-  for (auto &t : threads) {
-    t.join();
-  }
+// Compare the results of AnalysisConfig and AnalysisConfig + ZeroCopy
+TEST(Analyzer_rnn1, compare_zero_copy) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
 
-  LOG(INFO) << "average time: "
-            << total_time_of_threads / FLAGS_num_threads / FLAGS_repeat;
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  std::vector<std::string> outputs_name;
+  outputs_name.emplace_back("final_output.tmp_1");
+  CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
+                             input_slots_all, outputs_name);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
index cca2ab1ee1..19fa5528da 100644
--- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
@@ -144,6 +144,9 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   cfg->SwitchSpecifyInputNames();
   cfg->SwitchIrDebug();
   cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  if (FLAGS_zero_copy) {
+    cfg->SwitchUseFeedFetchOps(false);
+  }
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
   }
@@ -184,10 +187,10 @@ TEST(Analyzer_seq_pool1, compare_determine) {
                        input_slots_all);
 }
 
-void analysis_fuse_statis(bool use_zerocopy) {
+// Check the fuse status
+TEST(Analyzer_seq_pool1, fuse_statis) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  cfg.SwitchUseFeedFetchOps(!use_zerocopy);
   int num_ops;
   auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
   auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
@@ -203,137 +206,17 @@ void analysis_fuse_statis(bool use_zerocopy) {
   EXPECT_EQ(num_ops, 171);
 }
 
-// Check the fuse status
-TEST(Analyzer_seq_pool1, fuse_statis) { analysis_fuse_statis(false); }
-
-void PrepareZeroCopyInputs(
-    const std::unique_ptr<PaddlePredictor> &predictor,
-    std::vector<std::unique_ptr<ZeroCopyTensor>> *inputs) {
-  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
-  // only feed one batch
-  const auto &one_batch = data.NextBatch();
-  inputs->clear();
-  for (size_t i = 0; i < one_batch.size(); ++i) {
-    auto &slot = one_batch[i];
-    auto tensor = predictor->GetInputTensor(slot.name + "_embed");
-    tensor->Reshape(slot.shape);
-    tensor->SetLoD({slot.lod});
-    ZeroCopyTensorAssignData<float>(tensor.get(), slot.data);
-    inputs->emplace_back(std::move(tensor));
-  }
-}
-
-// return the output values
-std::vector<float> zerocopy_profile(int repeat_times) {
-  AnalysisConfig config;
-  SetConfig(&config);
-  config.SwitchUseFeedFetchOps(false);
-  auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
-  std::vector<std::unique_ptr<ZeroCopyTensor>> inputs;
-  PrepareZeroCopyInputs(predictor, &inputs);
-  auto output_tensor = predictor->GetOutputTensor(out_var_name);
-  Timer timer;
-  LOG(INFO) << "Warm up run...";
-  timer.tic();
-  predictor->ZeroCopyRun();
-  PrintTime(FLAGS_batch_size, 1, 1, 0, timer.toc(), 1);
-  if (FLAGS_profile) {
-    paddle::platform::ResetProfiler();
-  }
-  LOG(INFO) << "Run " << repeat_times << " times...";
-  timer.tic();
-  for (int i = 0; i < repeat_times; i++) {
-    predictor->ZeroCopyRun();
-  }
-  PrintTime(FLAGS_batch_size, repeat_times, 1, 0, timer.toc() / repeat_times,
-            1);
-
-  LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor);
-  PaddlePlace place;
-  int output_size{0};
-  auto *pdata = output_tensor->data<float>(&place, &output_size);
-  std::vector<float> res(output_size);
-  for (int i = 0; i < output_size; ++i) {
-    res[i] = pdata[i];
-  }
-  return res;
-}
-
-TEST(Analyzer_seq_pool1, zerocopy_profile) { zerocopy_profile(FLAGS_repeat); }
-
-TEST(Analyzer_seq_pool1, zerocopy_profile_threads) {
-  AnalysisConfig config;
-  SetConfig(&config);
-  config.SwitchUseFeedFetchOps(false);
-
-  std::vector<std::unique_ptr<PaddlePredictor>> predictors;
-  predictors.emplace_back(CreatePaddlePredictor<AnalysisConfig>(config));
-  for (int tid = 1; tid < FLAGS_num_threads; tid++) {
-    predictors.emplace_back(predictors.front()->Clone());
-  }
-  double total_time_of_threads{0};
-  std::vector<std::thread> threads;
-
-  for (int tid = 0; tid < FLAGS_num_threads; tid++) {
-    threads.emplace_back([&, tid] {
-      auto &predictor = predictors[tid];
-      std::vector<std::unique_ptr<ZeroCopyTensor>> inputs;
-      PrepareZeroCopyInputs(predictor, &inputs);
-      auto output_tensor = predictor->GetOutputTensor(out_var_name);
-      Timer timer;
-      double total_time{0};
-
-      LOG(INFO) << "Warm up run...";
-      timer.tic();
-      predictor->ZeroCopyRun();
-      PrintTime(FLAGS_batch_size, 1, FLAGS_num_threads, tid, timer.toc(), 1);
-      if (FLAGS_profile) {
-        paddle::platform::ResetProfiler();
-      }
-      int repeat_times = FLAGS_repeat;
-      LOG(INFO) << "Run " << repeat_times << " times...";
-      timer.tic();
-
-      for (int i = 0; i < repeat_times; i++) {
-        predictor->ZeroCopyRun();
-      }
-      total_time += timer.toc();
-      total_time_of_threads += total_time;
-
-      LOG(INFO) << "thread time: " << total_time / repeat_times;
-    });
-  }
-
-  for (auto &t : threads) {
-    t.join();
-  }
-
-  LOG(INFO) << "average time: "
-            << total_time_of_threads / FLAGS_num_threads / FLAGS_repeat;
-}
-
-TEST(Analyzer_seq_pool1, zerocopy_fuse_statis) { analysis_fuse_statis(true); }
+// Compare the results of AnalysisConfig and AnalysisConfig + ZeroCopy
+TEST(Analyzer_seq_pool1, compare_zero_copy) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
 
-TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
-  AnalysisConfig config;
-  SetConfig(&config);
-  config.SwitchUseFeedFetchOps(true);
-  auto predictor = CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
-  std::vector<PaddleTensor> native_outputs;
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  ASSERT_TRUE(predictor->Run(input_slots_all[0], &native_outputs));
-  EXPECT_EQ(native_outputs.size(), 1UL);
-
-  auto zerocopy_output = zerocopy_profile(1);
-  EXPECT_EQ(zerocopy_output.size() * sizeof(float),
-            native_outputs.front().data.length());
-  auto *native_data = static_cast<float *>(native_outputs.front().data.data());
-  for (size_t i = 0; i < zerocopy_output.size(); ++i) {
-    EXPECT_LT(
-        std::fabs((zerocopy_output[i] - native_data[i]) / zerocopy_output[i]),
-        1e-3);
-  }
+  std::vector<std::string> outputs_name;
+  outputs_name.emplace_back(out_var_name);
+  CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
+                             input_slots_all, outputs_name);
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 41daff83c4..a4881afe58 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -50,6 +50,7 @@ DEFINE_bool(use_analysis, true,
 DEFINE_bool(record_benchmark, false,
             "Record benchmark after profiling the model");
 DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
+DEFINE_bool(zero_copy, false, "Use ZeroCopy to speed up Feed/Fetch.");
 
 DECLARE_bool(profile);
 DECLARE_int32(paddle_num_threads);
@@ -67,6 +68,7 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
   LOG(INFO) << analysis_config->ToNativeConfig();
 }
 
+// Compare the results between two vectors of PaddleTensor
 void CompareResult(const std::vector<PaddleTensor> &outputs,
                    const std::vector<PaddleTensor> &ref_outputs) {
   EXPECT_GT(outputs.size(), 0UL);
@@ -108,6 +110,50 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   }
 }
 
+// Compare the results between PaddleTensor and ZeroCopyTensor outputs
+void CompareResult(const std::vector<PaddleTensor> &outputs,
+                   const std::vector<ZeroCopyTensor> &ref_outputs) {
+  EXPECT_GT(outputs.size(), 0UL);
+  EXPECT_EQ(outputs.size(), ref_outputs.size());
+  for (size_t i = 0; i < outputs.size(); i++) {
+    auto &out = outputs[i];
+    auto &ref_out = ref_outputs[i];
+    size_t size = VecReduceToInt(out.shape);
+    EXPECT_GT(size, 0UL);
+    int ref_size = 0;  // this is the number of elements, not the memory size
+    PaddlePlace place;
+    switch (out.dtype) {
+      case PaddleDType::INT64: {
+        int64_t *pdata = static_cast<int64_t *>(out.data.data());
+        int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size);
+        EXPECT_EQ(size, ref_size);
+        for (size_t j = 0; j < size; ++j) {
+          EXPECT_EQ(pdata_ref[j], pdata[j]);
+        }
+        break;
+      }
+      case PaddleDType::FLOAT32: {
+        float *pdata = static_cast<float *>(out.data.data());
+        float *pdata_ref = ref_out.data<float>(&place, &ref_size);
+        EXPECT_EQ(size, ref_size);
+        for (size_t j = 0; j < size; ++j) {
+          CHECK_LE(std::abs(pdata_ref[j] - pdata[j]), FLAGS_accuracy);
+        }
+        break;
+      }
+      case PaddleDType::INT32: {
+        int32_t *pdata = static_cast<int32_t *>(out.data.data());
+        int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
+        EXPECT_EQ(size, ref_size);
+        for (size_t j = 0; j < size; ++j) {
+          EXPECT_EQ(pdata_ref[j], pdata[j]);
+        }
+        break;
+      }
+    }
+  }
+}
+
 std::unique_ptr<PaddlePredictor> CreateTestPredictor(
     const PaddlePredictor::Config *config, bool use_analysis = true) {
   const auto *analysis_config =
@@ -205,61 +251,106 @@ void GetInputPerBatch(const std::vector<std::vector<int64_t>> &in,
   }
 }
 
-void TestOneThreadPrediction(
-    const PaddlePredictor::Config *config,
-    const std::vector<std::vector<PaddleTensor>> &inputs,
-    std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
-  int batch_size = FLAGS_batch_size;
-  int num_times = FLAGS_repeat;
-  auto predictor = CreateTestPredictor(config, use_analysis);
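+// Feed the given PaddleTensor inputs into the predictor's ZeroCopyTensor
+// inputs: reshape, set the LoD and copy the data according to the dtype.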
+void ConvertPaddleTensorToZeroCopyTensor(
+    PaddlePredictor *predictor, const std::vector<PaddleTensor> &inputs) {
+  for (size_t i = 0; i < inputs.size(); i++) {
+    auto input = inputs[i];
+    auto tensor = predictor->GetInputTensor(input.name);
+    tensor->Reshape(input.shape);
+    tensor->SetLoD({input.lod});
+    if (input.dtype == PaddleDType::INT64) {
+      ZeroCopyTensorAssignData<int64_t>(tensor.get(), input.data);
+    } else if (input.dtype == PaddleDType::FLOAT32) {
+      ZeroCopyTensorAssignData<float>(tensor.get(), input.data);
+    } else if (input.dtype == PaddleDType::INT32) {
+      ZeroCopyTensorAssignData<int32_t>(tensor.get(), input.data);
+    } else {
+      LOG(ERROR) << "unsupported feed type " << input.dtype;
+    }
+  }
+}
 
-  // warmup run
-  LOG(INFO) << "Warm up run...";
-  {
-    Timer warmup_timer;
-    warmup_timer.tic();
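+// Run a single warm-up pass over the first input batch before the timed runs,
+// and reset the profiler if profiling is enabled.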
+void PredictionWarmUp(PaddlePredictor *predictor,
+                      const std::vector<std::vector<PaddleTensor>> &inputs,
+                      std::vector<PaddleTensor> *outputs, int num_threads,
+                      int tid) {
+  int batch_size = FLAGS_batch_size;
+  LOG(INFO) << "Running thread " << tid << ", warm up run...";
+  if (FLAGS_zero_copy) {
+    ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]);
+  }
+  Timer warmup_timer;
+  warmup_timer.tic();
+  if (!FLAGS_zero_copy) {
     predictor->Run(inputs[0], outputs, batch_size);
-    PrintTime(batch_size, 1, 1, 0, warmup_timer.toc(), 1);
-    if (FLAGS_profile) {
-      paddle::platform::ResetProfiler();
-    }
+  } else {
+    predictor->ZeroCopyRun();
+  }
+  PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1);
+  if (FLAGS_profile) {
+    paddle::platform::ResetProfiler();
   }
+}
 
-  LOG(INFO) << "Run " << num_times << " times...";
-  {
-    Timer run_timer;
-    run_timer.tic();
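+// Run the predictor FLAGS_repeat times over every input batch and report the
+// average latency; uses ZeroCopyRun() when FLAGS_zero_copy is set.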
+void PredictionRun(PaddlePredictor *predictor,
+                   const std::vector<std::vector<PaddleTensor>> &inputs,
+                   std::vector<PaddleTensor> *outputs, int num_threads,
+                   int tid) {
+  int batch_size = FLAGS_batch_size;
+  int num_times = FLAGS_repeat;
+  LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
+  Timer run_timer;
+  double elapsed_time = 0;
 #ifdef WITH_GPERFTOOLS
-    ProfilerStart("paddle_inference.prof");
+  ProfilerStart("paddle_inference.prof");
 #endif
-    for (int i = 0; i < num_times; i++) {
-      for (size_t j = 0; j < inputs.size(); j++) {
-        predictor->Run(inputs[j], outputs, batch_size);
+  if (!FLAGS_zero_copy) {
+    run_timer.tic();
+    for (size_t i = 0; i < inputs.size(); i++) {
+      for (int j = 0; j < num_times; j++) {
+        predictor->Run(inputs[i], outputs, batch_size);
+      }
+    }
+    elapsed_time = run_timer.toc();
+  } else {
+    for (size_t i = 0; i < inputs.size(); i++) {
+      ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]);
+      run_timer.tic();
+      for (int j = 0; j < num_times; j++) {
+        predictor->ZeroCopyRun();
       }
+      elapsed_time += run_timer.toc();
     }
+  }
 #ifdef WITH_GPERFTOOLS
-    ProfilerStop();
+  ProfilerStop();
 #endif
 
-    double latency = run_timer.toc() / (num_times > 1 ? num_times : 1);
-    PrintTime(batch_size, num_times, 1, 0, latency, inputs.size());
-    if (FLAGS_record_benchmark) {
-      Benchmark benchmark;
-      benchmark.SetName(FLAGS_model_name);
-      benchmark.SetBatchSize(batch_size);
-      benchmark.SetLatency(latency);
-      benchmark.PersistToFile("benchmark_record.txt");
-    }
+  PrintTime(batch_size, num_times, num_threads, tid, elapsed_time / num_times,
+            inputs.size());
+  if (FLAGS_record_benchmark) {
+    Benchmark benchmark;
+    benchmark.SetName(FLAGS_model_name);
+    benchmark.SetBatchSize(batch_size);
+    benchmark.SetLatency(elapsed_time / num_times);
+    benchmark.PersistToFile("benchmark_record.txt");
   }
 }
 
+void TestOneThreadPrediction(
+    const PaddlePredictor::Config *config,
+    const std::vector<std::vector<PaddleTensor>> &inputs,
+    std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
+  auto predictor = CreateTestPredictor(config, use_analysis);
+  PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
+  PredictionRun(predictor.get(), inputs, outputs, 1, 0);
+}
+
 void TestMultiThreadPrediction(
     const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs,
     std::vector<PaddleTensor> *outputs, int num_threads,
     bool use_analysis = true) {
-  int batch_size = FLAGS_batch_size;
-  int num_times = FLAGS_repeat;
   std::vector<std::thread> threads;
   std::vector<std::unique_ptr<PaddlePredictor>> predictors;
   predictors.emplace_back(CreateTestPredictor(config, use_analysis));
@@ -267,7 +358,6 @@ void TestMultiThreadPrediction(
     predictors.emplace_back(predictors.front()->Clone());
   }
 
-  size_t total_time{0};
   for (int tid = 0; tid < num_threads; ++tid) {
     threads.emplace_back([&, tid]() {
       // Each thread should have local inputs and outputs.
@@ -280,34 +370,8 @@ void TestMultiThreadPrediction(
             ->SetMkldnnThreadID(static_cast<int>(tid) + 1);
       }
 #endif
-
-      // warmup run
-      LOG(INFO) << "Running thread " << tid << ", warm up run...";
-      {
-        Timer warmup_timer;
-        warmup_timer.tic();
-        predictor->Run(inputs[0], outputs, batch_size);
-        PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1);
-        if (FLAGS_profile) {
-          paddle::platform::ResetProfiler();
-        }
-      }
-
-      LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
-      {
-        Timer timer;
-        timer.tic();
-        for (int i = 0; i < num_times; i++) {
-          for (const auto &input : inputs) {
-            ASSERT_TRUE(predictor->Run(input, &outputs_tid));
-          }
-        }
-
-        auto time = timer.toc();
-        total_time += time;
-        PrintTime(batch_size, num_times, num_threads, tid, time / num_times,
-                  inputs.size());
-      }
+      PredictionWarmUp(predictor.get(), inputs, outputs, num_threads, tid);
+      PredictionRun(predictor.get(), inputs, outputs, num_threads, tid);
     });
   }
   for (int i = 0; i < num_threads; ++i) {
@@ -367,6 +431,31 @@ void CompareNativeAndAnalysis(
   CompareResult(analysis_outputs, native_outputs);
 }
 
+void CompareAnalysisAndZeroCopy(
+    PaddlePredictor::Config *config,
+    const std::vector<std::vector<PaddleTensor>> &inputs,
+    const std::vector<std::string> &outputs_name) {
+  int batch_size = FLAGS_batch_size;
+  // analysis
+  std::vector<PaddleTensor> analysis_outputs;
+  auto predictor = CreateTestPredictor(config, true);
+  predictor->Run(inputs[0], &analysis_outputs, batch_size);
+  // analysis + zero_copy
+  std::vector<ZeroCopyTensor> zerocopy_outputs;
+  reinterpret_cast<AnalysisConfig *>(config)->SwitchUseFeedFetchOps(false);
+  predictor = CreateTestPredictor(config, true);
+  ConvertPaddleTensorToZeroCopyTensor(predictor.get(), inputs[0]);
+  predictor->ZeroCopyRun();
+  for (size_t i = 0; i < outputs_name.size(); i++) {
+    ZeroCopyTensor zerocopy_output =
+        *predictor->GetOutputTensor(outputs_name[i]).get();
+    zerocopy_outputs.emplace_back(zerocopy_output);
+    LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(zerocopy_output);
+  }
+  // compare
+  CompareResult(analysis_outputs, zerocopy_outputs);
+}
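+
+// Typical usage (see e.g. analyzer_rnn1_tester.cc):
+//   std::vector<std::string> outputs_name;
+//   outputs_name.emplace_back("final_output.tmp_1");
+//   CompareAnalysisAndZeroCopy(
+//       reinterpret_cast<PaddlePredictor::Config *>(&cfg), input_slots_all,
+//       outputs_name);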
+
 template <typename T>
 std::string LoDTensorSummary(const framework::LoDTensor &tensor) {
   std::stringstream ss;
diff --git a/paddle/fluid/inference/tests/test.cmake b/paddle/fluid/inference/tests/test.cmake
index 6c5fe043ff..f551b322fe 100644
--- a/paddle/fluid/inference/tests/test.cmake
+++ b/paddle/fluid/inference/tests/test.cmake
@@ -30,19 +30,20 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
       ${EXTERNAL_PROJECT_NAME}
       ${EXTERNAL_PROJECT_LOG_ARGS}
       PREFIX                ${INSTALL_DIR}
-      URL                   ${URL}/${FILENAME}
+      DOWNLOAD_COMMAND      wget -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} &&
+                            ${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME}
       DOWNLOAD_DIR          ${INSTALL_DIR}
       DOWNLOAD_NO_PROGRESS  1
       CONFIGURE_COMMAND     ""
       BUILD_COMMAND         ""
       UPDATE_COMMAND        ""
-      INSTALL_COMMAND       ${CMAKE_COMMAND} -E copy_directory ${UNPACK_DIR} ${INSTALL_DIR}
+      INSTALL_COMMAND       ""
   )
 endfunction()
 
 set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec")
-if (NOT EXISTS ${WORD2VEC_INSTALL_DIR})
-    inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz")
+if(NOT EXISTS ${WORD2VEC_INSTALL_DIR} AND NOT WIN32)
+  inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz")
 endif()
 set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model")
 
diff --git a/paddle/fluid/memory/allocation/legacy_allocator.cc b/paddle/fluid/memory/allocation/legacy_allocator.cc
index 1936f9d4cd..a97d54a191 100644
--- a/paddle/fluid/memory/allocation/legacy_allocator.cc
+++ b/paddle/fluid/memory/allocation/legacy_allocator.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/memory/allocation/legacy_allocator.h"
 
+#include <memory>
 #include <string>
 #include <utility>
 #include <vector>
diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt
index b614e9b035..7aa1c44eaa 100644
--- a/paddle/fluid/operators/controlflow/CMakeLists.txt
+++ b/paddle/fluid/operators/controlflow/CMakeLists.txt
@@ -1,4 +1,5 @@
 include(operators)
 register_operators(DEPS naive_executor)
+cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
 
 file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc
index 0360cf5273..8352ba4f2b 100644
--- a/paddle/fluid/operators/controlflow/while_op.cc
+++ b/paddle/fluid/operators/controlflow/while_op.cc
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 
 namespace paddle {
@@ -26,14 +27,6 @@ namespace operators {
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
 
-static constexpr char kStepBlock[] = "sub_block";
-static constexpr char kCondition[] = "Condition";
-static constexpr char kStepScopes[] = "StepScopes";
-static constexpr char kX[] = "X";
-static constexpr char kXGRAD[] = "X@GRAD";
-static constexpr char kOutputs[] = "Out";
-static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
-
 namespace {  // NOLINT
 static std::string GetSkipEagerDeletionVarsDebugString(
     const std::vector<std::string> &vars) {
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc
new file mode 100644
index 0000000000..2cbd94a061
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -0,0 +1,291 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace operators {
+
+// OpVariant is a wrapper class of OpDesc and OperatorBase, so that both
+// can be accessed through the same API.
+class OpVariant {
+  struct InputsVisitor
+      : public boost::static_visitor<const framework::VariableNameMap *> {
+    template <typename OpType>
+    const framework::VariableNameMap *operator()(const OpType *op) const {
+      return &(op->Inputs());
+    }
+  };
+
+  struct OutputsVisitor
+      : public boost::static_visitor<const framework::VariableNameMap *> {
+    template <typename OpType>
+    const framework::VariableNameMap *operator()(const OpType *op) const {
+      return &(op->Outputs());
+    }
+  };
+
+  struct AttributeMapVisitor
+      : public boost::static_visitor<const framework::AttributeMap *> {
+    const framework::AttributeMap *operator()(
+        const framework::OpDesc *op) const {
+      return &(op->GetAttrMap());
+    }
+
+    const framework::AttributeMap *operator()(
+        const framework::OperatorBase *op) const {
+      return &(op->Attrs());
+    }
+  };
+
+  struct RawPointerVisitor : public boost::static_visitor<const void *> {
+    template <typename OpType>
+    const void *operator()(const OpType *op) const {
+      return op;
+    }
+  };
+
+ public:
+  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT
+
+  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT
+
+  const framework::VariableNameMap &Inputs() const {
+    return *boost::apply_visitor(InputsVisitor(), op_);
+  }
+
+  const framework::VariableNameMap &Outputs() const {
+    return *boost::apply_visitor(OutputsVisitor(), op_);
+  }
+
+  const framework::AttributeMap &Attrs() const {
+    return *boost::apply_visitor(AttributeMapVisitor(), op_);
+  }
+
+  template <typename AttrType>
+  const AttrType &Attr(const std::string &name) const {
+    auto &attrs = Attrs();
+    auto it = attrs.find(name);
+    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
+    return boost::get<AttrType>(it->second);
+  }
+
+  bool operator==(const OpVariant &other) const {
+    return RawPointer() == other.RawPointer();
+  }
+
+  const void *RawPointer() const {
+    return boost::apply_visitor(RawPointerVisitor(), op_);
+  }
+
+  int which() const { return static_cast<int>(op_.which()); }
+
+  struct Hasher {
+    size_t operator()(const OpVariant &op) const {
+      return reinterpret_cast<size_t>(op.RawPointer());
+    }
+  };
+
+ private:
+  const boost::variant<const framework::OperatorBase *,
+                       const framework::OpDesc *>
+      op_;
+};
+
+static std::string GetDebugString(const std::vector<std::string> &names) {
+  if (names.empty()) return "";
+  std::string ret = names[0];
+  for (size_t i = 1; i < names.size(); ++i) {
+    ret += (" " + names[i]);
+  }
+  return ret;
+}
+
+// Set the skip variables of while_op and while_grad_op.
+// These variables should be skipped when eager deletion is enabled.
+// It is because:
+//  1. while_grad_op needs some variables defined in while_op.
+//  2. while_grad_op needs variables from the previous time step.
+static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
+  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
+  VLOG(2) << "Prepare to skip " << attr.size()
+          << " var(s): " << GetDebugString(attr);
+  attrs[kSkipEagerDeletionVars] = std::move(attr);
+}
+
+// Check whether the forward while_op matches the while_grad_op,
+// since the program may contain many while_ops.
+static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
+                                           const OpVariant &grad_op) {
+  return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
+         fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
+}
+
+// Test whether a variable is skippable in the forward while_op.
+// A variable is skippable in while_op when it is used in while_grad_op
+// but is not defined in grad_block.
+static bool IsSkippableVar(const std::string &name,
+                           framework::BlockDesc *grad_block) {
+  return name != framework::kEmptyVarName && !grad_block->HasVar(name);
+}
+
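+// For a matched pair of while_op and while_grad_op, collect the variables
+// that must stay alive and record them in the kSkipEagerDeletionVars
+// attribute of each op.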
+static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
+                                            const OpVariant &bwd_op) {
+  auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
+
+  // Find all skippable variables in forward while_op
+  std::unordered_set<std::string> forward_skip_vars;
+  for (auto *op_desc : grad_block->AllOps()) {
+    for (auto &in_arg_name : op_desc->InputArgumentNames()) {
+      if (IsSkippableVar(in_arg_name, grad_block)) {
+        forward_skip_vars.insert(in_arg_name);
+      }
+    }
+
+    for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
+      if (IsSkippableVar(out_arg_name, grad_block)) {
+        forward_skip_vars.insert(out_arg_name);
+      }
+    }
+  }
+
+  SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
+                                               forward_skip_vars.end()));
+
+  // Find all skippable variables in while_grad_op
+  // The skipped variables are those which would be used across time steps.
+  auto &fwd_input = fwd_op.Inputs().at(kX);
+  auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
+  PADDLE_ENFORCE_EQ(
+      fwd_input.size(), in_grads.size(),
+      "Backward input gradient number does not match forward input number.");
+
+  std::unordered_set<std::string> backward_skip_vars;
+  for (size_t i = 0; i < in_grads.size(); ++i) {
+    if (in_grads[i] == framework::kEmptyVarName) {
+      continue;
+    }
+    backward_skip_vars.insert(in_grads[i]);
+    backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
+  }
+
+  SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
+                                               backward_skip_vars.end()));
+}
+
+// Find all while_ops and while_grad_ops in the graph or program.
+// A while_op and its while_grad_op may be located in different blocks,
+// so we should traverse all blocks of the program to find them.
+static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
+                                       std::vector<OpVariant> *while_grad_ops) {
+  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
+
+  if (while_ops->empty()) return;
+
+  const auto *program =
+      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
+  for (size_t i = 1; i < program->Size(); ++i) {
+    auto &block = program->Block(i);
+    for (size_t j = 0; j < block.OpSize(); ++j) {
+      auto *op = block.Op(j);
+      if (op->Type() == "while") {
+        while_ops->emplace_back(op);
+      } else if (op->Type() == "while_grad") {
+        while_grad_ops->emplace_back(op);
+      }
+    }
+  }
+
+  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
+                    "There are extra while_grad ops in the graph or program");
+}
+
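+// Match every while_grad_op with its forward while_op and set the skip
+// variables on both of them.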
+static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
+    std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
+  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
+
+  VLOG(2) << "Found while op num: " << while_ops->size()
+          << ", while grad op num: " << while_grad_ops->size();
+
+  if (while_grad_ops->empty()) {
+    return;
+  }
+
+  std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
+      while_ops->begin(), while_ops->end());
+
+  for (auto &bwd_op : *while_grad_ops) {
+    const OpVariant *matched_fwd_op = nullptr;
+    for (auto &fwd_op : while_op_set) {
+      if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
+        PADDLE_ENFORCE(matched_fwd_op == nullptr,
+                       "Found multiple matched while ops");
+        matched_fwd_op = &fwd_op;
+      }
+    }
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                            "Cannot find matched forward while op.");
+    ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
+    while_op_set.erase(*matched_fwd_op);
+  }
+}
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
+  // If block_id is not 0, return immediately.
+  // This is because all while_ops and while_grad_ops in the whole program
+  // are processed when block_id is 0 (i.e. when Executor::Run() is called or
+  // the ParallelExecutor is constructed).
+
+  // What's more, all while_ops and while_grad_ops must be processed when
+  // block_id is zero. Otherwise, a while_op may run first and erase variables
+  // still needed by its while_grad_op, which may not have been constructed
+  // yet at that moment.
+  if (block_id != 0) return;
+
+  std::vector<OpVariant> fwd_ops, bwd_ops;
+  for (auto &op : all_ops) {
+    if (op->Type() == "while") {
+      fwd_ops.emplace_back(op.get());
+    } else if (op->Type() == "while_grad") {
+      bwd_ops.emplace_back(op.get());
+    }
+  }
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+}
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops) {
+  std::vector<OpVariant> fwd_ops, bwd_ops;
+  fwd_ops.reserve(while_ops.size());
+  for (auto *op : while_ops) {
+    fwd_ops.emplace_back(op);
+  }
+
+  bwd_ops.reserve(while_grad_ops.size());
+  for (auto *op : while_grad_ops) {
+    bwd_ops.emplace_back(op);
+  }
+
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.h b/paddle/fluid/operators/controlflow/while_op_helper.h
new file mode 100644
index 0000000000..456ba8642b
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/while_op_helper.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/variant.h"
+
+namespace paddle {
+namespace operators {
+
+static constexpr char kStepBlock[] = "sub_block";
+static constexpr char kCondition[] = "Condition";
+static constexpr char kStepScopes[] = "StepScopes";
+static constexpr char kX[] = "X";
+static constexpr char kXGRAD[] = "X@GRAD";
+static constexpr char kOutputs[] = "Out";
+static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
+
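+// Mark the variables that eager deletion must skip for matched
+// while_op / while_grad_op pairs. The first overload takes the operators
+// created for a block and is a no-op unless block_id is 0; the second takes
+// while ops collected elsewhere, e.g. by the while_op eager deletion pass.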
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops);
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h
index 72774a878d..d6b54038ec 100644
--- a/paddle/fluid/operators/crf_decoding_op.h
+++ b/paddle/fluid/operators/crf_decoding_op.h
@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
     Tensor track;
     int* track_value =
         track.mutable_data<int>(emission_dims, platform::CPUPlace());
-    auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
-                        platform::CPUPlace>(tag_num);
+    auto ker =
+        jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
+            .At(tag_num);
     ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
     T max_score = -std::numeric_limits<T>::max();
     int max_i = 0;
diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
index 04e8800bbc..f2f4d3fee0 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
@@ -110,8 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
         constexpr int simd_width = 16;
         int C = c / simd_width;
 
-        auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
-                                 platform::CPUPlace>(0);
+        auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
+                                         platform::CPUPlace>::Cache()
+                            .At(0);
 #pragma omp parallel for collapse(2)
         for (int ni = 0; ni < n; ni++) {
           for (int ci = 0; ci < C; ci++) {
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index f13c020386..5e2e336e71 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -52,8 +52,9 @@ struct EmbeddingVSumFunctor {
                                   out_width, jit::SeqPoolType::kSum);
     for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
       attr.index_height = ids_lod[i + 1] - ids_lod[i];
-      auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
-                                  platform::CPUPlace>(attr);
+      auto emb_seqpool =
+          jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
+              .At(attr);
       emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
                   &attr);
     }
@@ -135,8 +136,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
       const T *d_output_data = d_output->data<T>();
 
-      auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
-                                 platform::CPUPlace>(out_width);
+      auto vbroadcast =
+          jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
+              .At(out_width);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
         const T *src = d_output_data + i * out_width;
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 66acba49e5..ba5f0747c4 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -182,29 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int total_T = x_dims[0];           \
   const int D3 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES                                                     \
-  auto* h0 = ctx.Input<Tensor>("H0");                                          \
-  auto* wx = ctx.Input<Tensor>("WeightX");                                     \
-  auto* bias = ctx.Input<Tensor>("Bias");                                      \
-  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                          \
-  bool is_reverse = ctx.Attr<bool>("is_reverse");                              \
-  const int M = x_dims[1];                                                     \
-  const int D = wh_dims[0];                                                    \
-  const int D2 = D * 2;                                                        \
-  const jit::gru_attr_t attr(                                                  \
-      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),         \
-      jit::to_kerneltype(ctx.Attr<std::string>("activation")));                \
-  jit::gru_t one_step;                                                         \
-  auto ComputeH1 =                                                             \
-      jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr);      \
-  auto ComputeHtPart1 =                                                        \
-      jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
-  auto ComputeHtPart2 =                                                        \
-      jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
-  const T* x_data = x->data<T>();                                              \
-  const T* wx_data = wx->data<T>();                                            \
-  const T* wh_data = wh->data<T>();                                            \
-  auto place = ctx.GetPlace();                                                 \
+#define INIT_OTHER_DEFINES                                                   \
+  auto* h0 = ctx.Input<Tensor>("H0");                                        \
+  auto* wx = ctx.Input<Tensor>("WeightX");                                   \
+  auto* bias = ctx.Input<Tensor>("Bias");                                    \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                        \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");                            \
+  const int M = x_dims[1];                                                   \
+  const int D = wh_dims[0];                                                  \
+  const int D2 = D * 2;                                                      \
+  const jit::gru_attr_t attr(                                                \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),       \
+      jit::to_kerneltype(ctx.Attr<std::string>("activation")));              \
+  jit::gru_t one_step;                                                       \
+  auto ComputeH1 =                                                           \
+      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(  \
+          attr);                                                             \
+  auto ComputeHtPart1 =                                                      \
+      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
+          .At(attr);                                                         \
+  auto ComputeHtPart2 =                                                      \
+      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
+          .At(attr);                                                         \
+  const T* x_data = x->data<T>();                                            \
+  const T* wx_data = wx->data<T>();                                          \
+  const T* wh_data = wh->data<T>();                                          \
+  auto place = ctx.GetPlace();                                               \
   T* xx_data = xx->mutable_data<T>(place)
 
   void SeqCompute(const framework::ExecutionContext& ctx) const {
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index b11b7c11bf..c8c07bd126 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -235,32 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   const int D = wh_dims[0];                                 \
   const int D4 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES                                                    \
-  const T* x_data = x->data<T>();                                             \
-  const T* wx_data = wx->data<T>();                                           \
-  const T* wh_data = wh->data<T>();                                           \
-  /* diagonal weight*/                                                        \
-  const T* wp_data = bias->data<T>() + D4;                                    \
-  /* for peephole only*/                                                      \
-  T* checked_cell_data = nullptr;                                             \
-  auto place = ctx.GetPlace();                                                \
-  if (use_peepholes) {                                                        \
-    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                          \
-    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                   \
-    checked_cell_data = checked_cell->mutable_data<T>(place);                 \
-  }                                                                           \
-  const jit::lstm_attr_t attr(                                                \
-      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),        \
-      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),      \
-      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),           \
-      use_peepholes);                                                         \
-  jit::lstm_t one_step;                                                       \
-  one_step.wp = wp_data;                                                      \
-  one_step.checked = checked_cell_data;                                       \
-  auto ComputeC1H1 =                                                          \
-      jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
-  auto ComputeCtHt =                                                          \
-      jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
+#define INIT_OTHER_DEFINES                                                     \
+  const T* x_data = x->data<T>();                                              \
+  const T* wx_data = wx->data<T>();                                            \
+  const T* wh_data = wh->data<T>();                                            \
+  /* diagonal weight*/                                                         \
+  const T* wp_data = bias->data<T>() + D4;                                     \
+  /* for peephole only*/                                                       \
+  T* checked_cell_data = nullptr;                                              \
+  auto place = ctx.GetPlace();                                                 \
+  if (use_peepholes) {                                                         \
+    /* w_ic * Ct-1, w_fc * Ct-1  ; w_oc * Ct => ih*/                           \
+    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                    \
+    checked_cell_data = checked_cell->mutable_data<T>(place);                  \
+  }                                                                            \
+  const jit::lstm_attr_t attr(                                                 \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),         \
+      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),       \
+      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),            \
+      use_peepholes);                                                          \
+  jit::lstm_t one_step;                                                        \
+  one_step.wp = wp_data;                                                       \
+  one_step.checked = checked_cell_data;                                        \
+  auto ComputeC1H1 =                                                           \
+      jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
+          attr);                                                               \
+  auto ComputeCtHt =                                                           \
+      jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
+          attr)
 
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out)                                           \
diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
index 8ecdf2ed9d..6be35de65f 100644
--- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
+++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc
@@ -82,9 +82,11 @@ template <typename T>
 static void fc_relu(const T* x, const T* w, const T* b, T* y,
                     const jit::matmul_attr_t& attr) {
   auto matmul =
-      jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
+      jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
+          attr);
   auto addbias_relu =
-      jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n);
+      jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
+          attr.n);
   matmul(x, w, y, &attr);
   T* dst = y;
   for (int i = 0; i < attr.m; ++i) {
diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
index d48bdafe0a..25916768c0 100644
--- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc
@@ -98,7 +98,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
       attr.type = jit::SeqPoolType::kSqrt;
     }
     auto seqpool =
-        jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
+        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
             attr);
     size_t n = ins.size();
     size_t dst_step_size = n * w;
diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
index 8493f4468f..53679ebdde 100644
--- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
+++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc
@@ -94,19 +94,23 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
     int o_numel = attr.m * attr.n;
 
     auto vsquare_x =
-        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m *
-                                                                       attr.k);
+        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
+            attr.m * attr.k);
     auto vsquare_y =
-        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k *
-                                                                       attr.n);
+        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
+            attr.k * attr.n);
     auto vsquare_xy =
-        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
+        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
+            o_numel);
     auto vsub =
-        jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
+        jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
+            o_numel);
     auto vscal =
-        jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
+        jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
+            o_numel);
     auto matmul =
-        jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
+        jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
+            attr);
 
     const T* x_data = x->data<T>();
     const T* y_data = y->data<T>();
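// A minimal sketch of the call-site migration applied in the fused operators
// above. Assumptions: the generated "paddle/fluid/operators/jit/kernels.h"
// header is available, and float data on CPUPlace.
#include "paddle/fluid/operators/jit/kernels.h"

void scale_inplace_sketch(const float a, float* x, int n) {
  namespace jit = paddle::operators::jit;
  using CPUPlace = paddle::platform::CPUPlace;
  // Old:  jit::Get<jit::kVScal, jit::AXYNTuples<float>, CPUPlace>(n)
  // New:  the KernelType is carried by VScalTuple, and the selected function
  //       is cached per thread, keyed by the attribute.
  auto vscal =
      jit::KernelFuncs<jit::VScalTuple<float>, CPUPlace>::Cache().At(n);
  vscal(&a, x, x, n);  // x = a * x, in place
}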
diff --git a/paddle/fluid/operators/jit/CMakeLists.txt b/paddle/fluid/operators/jit/CMakeLists.txt
index 35775d7ec9..47d6c83f2a 100644
--- a/paddle/fluid/operators/jit/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/CMakeLists.txt
@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
 file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
 file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
 
-set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
+set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
 
 file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
 list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 3088280bb9..fbb04a166e 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
       InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_());     \
   void BenchJITKernel_##name##_##dtype##_##place##_::Run()
 
-#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
-
 void RUN_ALL_BENCHMARK() {
   for (auto p : g_all_benchmarks) {
     if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
@@ -90,11 +88,11 @@ std::vector<int> TestSizes() {
   return s;
 }
 
-template <typename KernelTuples, typename... Args>
+template <typename KernelTuple, typename... Args>
 struct BenchFunc {
   // return this function avg time
   // TODO(TJ): clear cache every time
-  double operator()(const typename KernelTuples::func_type tgt, Args... args) {
+  double operator()(const typename KernelTuple::func_type tgt, Args... args) {
     for (int i = 0; i < FLAGS_burning; ++i) {
       tgt(args...);
     }
@@ -109,40 +107,17 @@ struct BenchFunc {
 
 namespace jit = paddle::operators::jit;
 
-template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
-          typename... Args>
-void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
-  BenchFunc<KernelTuples, Args...> benchmark;
+template <typename KernelTuple, typename PlaceType, typename... Args>
+void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
+  BenchFunc<KernelTuple, Args...> benchmark;
   std::vector<std::pair<std::string, double>> infos;
-  // test refer
-  auto refer = jit::GetRefer<KT, KernelTuples>();
-  if (!refer) {
-    LOG(FATAL) << "Refer can not be empty!";
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
   }
-  infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
 
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
-  if (jitcode) {
-    infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
-  }
-  // test all impls in more
-  jit::KernelKey kkey(KT, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        infos.push_back(
-            std::make_pair(i->ImplType(), benchmark(more, args...)));
-      }
-    }
-  }
   // Test result from Get function
-  auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
+  auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
   if (!tgt) {
     LOG(FATAL) << "Target can not be empty!";
   }
@@ -150,7 +125,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
 
   // print
   std::ostringstream loginfos;
-  loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
+  loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": "
+           << attr << ": ";
   for (auto pair : infos) {
     loginfos << pair.first << " takes " << pair.second << " us; ";
   }
@@ -159,8 +135,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
 
 using Tensor = paddle::framework::Tensor;
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchXYZNKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelXYZN() {
+  using T = typename KernelTuple::data_type;
   for (int d : TestSizes()) {
     Tensor x, y, z;
     x.Resize({d});
@@ -171,16 +148,16 @@ void BenchXYZNKernel() {
     T* z_data = z.mutable_data<T>(PlaceType());
     RandomVec<T>(d, x_data);
     RandomVec<T>(d, y_data);
-    BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
-                                                     y.data<T>(), z_data, d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y.data<T>(), z_data,
+                                          d);
     // test inplace
-    BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
-                                                     z_data, d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), z_data, z_data, d);
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchAXYNKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelAXYN() {
+  using T = typename KernelTuple::data_type;
   for (int d : TestSizes()) {
     const T a = static_cast<T>(3);
     Tensor x, y;
@@ -189,26 +166,26 @@ void BenchAXYNKernel() {
     T* x_data = x.mutable_data<T>(PlaceType());
     T* y_data = y.mutable_data<T>(PlaceType());
     RandomVec<T>(d, x_data);
-    BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
-                                                     d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), y_data, d);
     // test inplace
-    BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
-                                                     d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), x_data, d);
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchXRNKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelXRN() {
+  using T = typename KernelTuple::data_type;
   for (int d : TestSizes()) {
     Tensor x;
     RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
     T res;
-    BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), &res, d);
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchXYNKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelXYN() {
+  using T = typename KernelTuple::data_type;
   for (int d : TestSizes()) {
     Tensor x, y;
     x.Resize({d});
@@ -216,12 +193,13 @@ void BenchXYNKernel() {
     T* x_data = x.mutable_data<T>(PlaceType());
     T* y_data = y.mutable_data<T>(PlaceType());
     RandomVec<T>(d, x_data);
-    BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y_data, d);
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchLSTMKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelLSTM() {
+  using T = typename KernelTuple::data_type;
   for (bool use_peephole : {true, false}) {
     for (int d : TestSizes()) {
       const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
@@ -252,13 +230,14 @@ void BenchLSTMKernel() {
         step.wp = wp_data;
         step.checked = checked_data;
       }
-      BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
+      BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchGRUKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelGRU() {
+  using T = typename KernelTuple::data_type;
   for (int d : TestSizes()) {
     const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
     auto place = PlaceType();
@@ -275,12 +254,13 @@ void BenchGRUKernel() {
     step.gates = x_data;
     step.ht_1 = ht_1_data;
     step.ht = ht_data;
-    BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
+    BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchSeqPoolKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelSeqPool() {
+  using T = typename KernelTuple::data_type;
   std::vector<jit::SeqPoolType> pool_types = {
       jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
   for (auto type : pool_types) {
@@ -294,15 +274,15 @@ void BenchSeqPoolKernel() {
         RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
         const T* x_data = x.data<T>();
         T* y_data = y.mutable_data<T>(PlaceType());
-        BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
-                                                            y_data, &attr);
+        BenchAllImpls<KernelTuple, PlaceType>(attr, x_data, y_data, &attr);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchEmbSeqPoolKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelEmbSeqPool() {
+  using T = typename KernelTuple::data_type;
   std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
   int64_t tbl_h = 1e4;
   for (int tbl_w : {10, 16, 256}) {
@@ -324,16 +304,17 @@ void BenchEmbSeqPoolKernel() {
                              tbl_h - 1);
           const int64_t* idx_data = idx.data<int64_t>();
           T* o_data = out.mutable_data<T>(PlaceType());
-          BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>(
-              attr, table_data, idx_data, o_data, &attr);
+          BenchAllImpls<KernelTuple, PlaceType>(attr, table_data, idx_data,
+                                                o_data, &attr);
         }
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchSgdKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelSgd() {
+  using T = typename KernelTuple::data_type;
   const T lr = 0.1;
   auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
                                   const int64_t upper) -> std::vector<int64_t> {
@@ -364,15 +345,16 @@ void BenchSgdKernel() {
         const T* grad_data = grad.data<T>();
         const int64_t* rows_data = rows.data();
         jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
-        BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>(
-            attr, &lr, param_data, grad_data, rows_data, param_data, &attr);
+        BenchAllImpls<KernelTuple, PlaceType>(attr, &lr, param_data, grad_data,
+                                              rows_data, param_data, &attr);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchMatMulKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelMatMul() {
+  using T = typename KernelTuple::data_type;
   for (int m : {1, 2, 3, 4}) {
     for (int n : TestSizes()) {
       for (int k : TestSizes()) {
@@ -386,15 +368,16 @@ void BenchMatMulKernel() {
         const T* b_data = b.data<T>();
         T* c_data = c.mutable_data<T>(PlaceType());
         const jit::matmul_attr_t attr{m, n, k};
-        BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data,
-                                                           c_data, &attr);
+        BenchAllImpls<KernelTuple, PlaceType>(attr, a_data, b_data, c_data,
+                                              &attr);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchSoftmaxKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelSoftmax() {
+  using T = typename KernelTuple::data_type;
   for (int bs : {1, 2, 10}) {
     for (int n : TestSizes()) {
       Tensor x, y;
@@ -403,14 +386,14 @@ void BenchSoftmaxKernel() {
       RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
       const T* x_data = x.data<T>();
       T* y_data = y.mutable_data<T>(PlaceType());
-      BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
-                                                          bs);
+      BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchLayerNormKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelLayerNorm() {
+  using T = typename KernelTuple::data_type;
   const T epsilon = 9.99999975e-06;
   for (int n : {1, 2, 10}) {
     for (int x_dim_0 : {1, 9, 17, 50}) {
@@ -439,16 +422,17 @@ void BenchLayerNormKernel() {
         T* var_data = var.data<T>();
         T* out_data = out.mutable_data<T>(PlaceType());
 
-        BenchAllImpls<KT, jit::LayerNormTuples<T>, PlaceType>(
-            right, x_data, out_data, mean_data, var_data, scale_data, bias_data,
-            left, epsilon, right);
+        BenchAllImpls<KernelTuple, PlaceType>(right, x_data, out_data,
+                                              mean_data, var_data, scale_data,
+                                              bias_data, left, epsilon, right);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchCRFDecodingKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelCRFDecoding() {
+  using T = typename KernelTuple::data_type;
   constexpr int state_trans_base_idx = 2;
   for (int seq_len : {1, 11, 17, 50}) {
     for (int tag_num : TestSizes()) {
@@ -468,14 +452,15 @@ void BenchCRFDecodingKernel() {
       T* alpha_data = alpha.mutable_data<T>(PlaceType());
       int* track_data = track.mutable_data<int>(PlaceType());
 
-      BenchAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType>(
-          tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num);
+      BenchAllImpls<KernelTuple, PlaceType>(tag_num, seq_len, x_data, w_data,
+                                            alpha_data, track_data, tag_num);
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void BenchVBroadcastKernel() {
+template <typename KernelTuple, typename PlaceType>
+void BenchKernelVBroadcast() {
+  using T = typename KernelTuple::data_type;
   for (int64_t w : {1, 16, 64, 100, 256}) {
     Tensor x;
     x.Resize({w});
@@ -485,78 +470,86 @@ void BenchVBroadcastKernel() {
       Tensor y;
       y.Resize({h * w});
       T* y_data = y.mutable_data<T>(PlaceType());
-      BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
-          w, x_data, y_data, static_cast<int64_t>(h), w);
+      BenchAllImpls<KernelTuple, PlaceType>(w, x_data, y_data,
+                                            static_cast<int64_t>(h), w);
     }
   }
 }
 
-using T = float;
-using CPUPlace = paddle::platform::CPUPlace;
+#define BenchKernelVMul BenchKernelXYZN
+#define BenchKernelVAdd BenchKernelXYZN
+#define BenchKernelVAddRelu BenchKernelXYZN
+#define BenchKernelVSub BenchKernelXYZN
 
-// xyzn
-BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
+#define BenchKernelVScal BenchKernelAXYN
+#define BenchKernelVAddBias BenchKernelAXYN
 
-// axyn
-BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
+#define BenchKernelVRelu BenchKernelXYN
+#define BenchKernelVIdentity BenchKernelXYN
+#define BenchKernelVSquare BenchKernelXYN
+#define BenchKernelVExp BenchKernelXYN
+#define BenchKernelVSigmoid BenchKernelXYN
+#define BenchKernelVTanh BenchKernelXYN
+#define BenchKernelVCopy BenchKernelXYN
 
-// xrn
-BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
-BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
+#define BenchKernelHMax BenchKernelXRN
+#define BenchKernelHSum BenchKernelXRN
 
-// xyn
-BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
-
-// lstm and peephole
-BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
-BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
-
-// gru functions
-BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
-BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
-BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
-
-// seq pool function
-BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
-
-// embedding seq pool function
-BENCH_FP32_CPU(kEmbSeqPool) {
-  BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
-}
+#define BenchKernelLSTMCtHt BenchKernelLSTM
+#define BenchKernelLSTMC1H1 BenchKernelLSTM
 
-// sgd function
-BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); }
+#define BenchKernelGRUH1 BenchKernelGRU
+#define BenchKernelGRUHtPart1 BenchKernelGRU
+#define BenchKernelGRUHtPart2 BenchKernelGRU
 
-// matmul
-BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
+using CPUPlace = paddle::platform::CPUPlace;
 
-// softmax
-BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
+#define BENCH_FP32_CPU(name)                                \
+  BENCH_JITKERNEL(name, FP32, CPU) {                        \
+    BenchKernel##name<jit::name##Tuple<float>, CPUPlace>(); \
+  }
 
-// layernorm
-BENCH_FP32_CPU(kLayerNorm) {
-  BenchLayerNormKernel<jit::kLayerNorm, T, CPUPlace>();
-}
+// xyzn
+BENCH_FP32_CPU(VMul);
+BENCH_FP32_CPU(VAdd);
+BENCH_FP32_CPU(VAddRelu);
+BENCH_FP32_CPU(VSub);
 
-// crfdecoding
-BENCH_FP32_CPU(kCRFDecoding) {
-  BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
-}
+// axyn
+BENCH_FP32_CPU(VScal);
+BENCH_FP32_CPU(VAddBias);
 
-// vbroadcast function
-BENCH_FP32_CPU(kVBroadcast) {
-  BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
-}
+// xyn
+BENCH_FP32_CPU(VRelu);
+BENCH_FP32_CPU(VIdentity);
+BENCH_FP32_CPU(VSquare);
+BENCH_FP32_CPU(VExp);
+BENCH_FP32_CPU(VSigmoid);
+BENCH_FP32_CPU(VTanh);
+BENCH_FP32_CPU(VCopy);
+
+// xrn
+BENCH_FP32_CPU(HMax);
+BENCH_FP32_CPU(HSum);
+
+// LSTM
+BENCH_FP32_CPU(LSTMCtHt);
+BENCH_FP32_CPU(LSTMC1H1);
+
+// GRU
+BENCH_FP32_CPU(GRUH1);
+BENCH_FP32_CPU(GRUHtPart1);
+BENCH_FP32_CPU(GRUHtPart2);
+
+BENCH_FP32_CPU(LayerNorm);
+BENCH_FP32_CPU(CRFDecoding);
+
+BENCH_FP32_CPU(SeqPool);
+BENCH_FP32_CPU(EmbSeqPool);
+BENCH_FP32_CPU(MatMul);
+BENCH_FP32_CPU(Softmax);
+BENCH_FP32_CPU(Sgd);
+BENCH_FP32_CPU(VBroadcast);
 
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
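// Illustration only: what a single registration above expands to (the
// BENCH_JITKERNEL macro kept earlier in this file generates it; nothing needs
// to be added by hand):
//
//   BENCH_FP32_CPU(VMul);
//     ==> BENCH_JITKERNEL(VMul, FP32, CPU) {
//           // BenchKernelVMul is an alias for BenchKernelXYZN, and
//           // VMulTuple<float> supplies kernel_type (kVMul) and data_type.
//           BenchKernelVMul<jit::VMulTuple<float>, CPUPlace>();
//         }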
diff --git a/paddle/fluid/operators/jit/gen/act.cc b/paddle/fluid/operators/jit/gen/act.cc
index e7a7375879..ad68e792c7 100644
--- a/paddle/fluid/operators/jit/gen/act.cc
+++ b/paddle/fluid/operators/jit/gen/act.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/act.h"
+#include <memory>
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -81,7 +82,7 @@ void VActJitCode::genCode() {
 #define DECLARE_ACT_CREATOR(name)                                            \
   class name##Creator : public JitCodeCreator<int> {                         \
    public:                                                                   \
-    bool UseMe(const int& attr) const override;                              \
+    bool CanBeUsed(const int& attr) const override;                          \
     size_t CodeSize(const int& d) const override;                            \
     std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
       return make_unique<name##JitCode>(attr, CodeSize(attr));               \
@@ -96,27 +97,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
 DECLARE_ACT_CREATOR(VTanh);
 
 // TODO(TJ): tuning use me
-bool VReluCreator::UseMe(const int& d) const {
+bool VReluCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VSquareCreator::UseMe(const int& d) const {
+bool VSquareCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VIdentityCreator::UseMe(const int& d) const {
+bool VIdentityCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VExpCreator::UseMe(const int& d) const {
+bool VExpCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d < 32;
 }
 
-bool VSigmoidCreator::UseMe(const int& d) const {
+bool VSigmoidCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
-bool VTanhCreator::UseMe(const int& d) const {
+bool VTanhCreator::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx);
 }
 
diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc
index 5da24c359e..c126b9077a 100644
--- a/paddle/fluid/operators/jit/gen/blas.cc
+++ b/paddle/fluid/operators/jit/gen/blas.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/blas.h"
+#include <memory>
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -142,7 +143,7 @@ void NCHW16CMulNCJitCode::genCode() {
 
 class NCHW16CMulNCCreator : public JitCodeCreator<int> {
  public:
-  bool UseMe(const int& attr) const override {
+  bool CanBeUsed(const int& attr) const override {
     return platform::MayIUse(platform::avx512f);
   }
   size_t CodeSize(const int& d) const override { return 256 * 1024; }
@@ -154,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
 #define DECLARE_BLAS_CREATOR(name)                                           \
   class name##Creator : public JitCodeCreator<int> {                         \
    public:                                                                   \
-    bool UseMe(const int& attr) const override {                             \
+    bool CanBeUsed(const int& attr) const override {                         \
       return platform::MayIUse(platform::avx) && attr <= 1024;               \
     }                                                                        \
     size_t CodeSize(const int& d) const override {                           \
diff --git a/paddle/fluid/operators/jit/gen/embseqpool.cc b/paddle/fluid/operators/jit/gen/embseqpool.cc
index 23837a3fb9..331a4b0d07 100644
--- a/paddle/fluid/operators/jit/gen/embseqpool.cc
+++ b/paddle/fluid/operators/jit/gen/embseqpool.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/operators/jit/gen/embseqpool.h"
 #include <stddef.h>  // offsetof
+#include <memory>
 #include <vector>
 #include "paddle/fluid/operators/jit/gen/act.h"  // for exp_float_consts ones
 #include "paddle/fluid/operators/jit/registry.h"
@@ -121,7 +122,7 @@ void EmbSeqPoolJitCode::genCode() {
 
 class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
  public:
-  bool UseMe(const emb_seq_pool_attr_t& attr) const override {
+  bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
     return platform::MayIUse(platform::avx) &&
            attr.table_width % YMM_FLOAT_BLOCK == 0;
   }
diff --git a/paddle/fluid/operators/jit/gen/gru.cc b/paddle/fluid/operators/jit/gen/gru.cc
index 13f7a14111..b5b0cffa80 100644
--- a/paddle/fluid/operators/jit/gen/gru.cc
+++ b/paddle/fluid/operators/jit/gen/gru.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/operators/jit/gen/gru.h"
 #include <stddef.h>  // offsetof
+#include <memory>
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -86,7 +87,7 @@ void GRUJitCode::genCode() {
   class name##Creator : public JitCodeCreator<gru_attr_t> {       \
    public:                                                        \
     /* TODO(TJ): enable more */                                   \
-    bool UseMe(const gru_attr_t& attr) const override {           \
+    bool CanBeUsed(const gru_attr_t& attr) const override {       \
       return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
     }                                                             \
     size_t CodeSize(const gru_attr_t& attr) const override {      \
diff --git a/paddle/fluid/operators/jit/gen/hopv.cc b/paddle/fluid/operators/jit/gen/hopv.cc
index e788401719..462ac68a93 100644
--- a/paddle/fluid/operators/jit/gen/hopv.cc
+++ b/paddle/fluid/operators/jit/gen/hopv.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/hopv.h"
+#include <memory>
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -76,7 +77,7 @@ void HOPVJitCode::genCode() {
 #define DECLARE_HOP_CREATOR(name)                                            \
   class name##Creator : public JitCodeCreator<int> {                         \
    public:                                                                   \
-    bool UseMe(const int& attr) const override {                             \
+    bool CanBeUsed(const int& attr) const override {                         \
       return platform::MayIUse(platform::avx);                               \
     }                                                                        \
     size_t CodeSize(const int& d) const override {                           \
diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h
index 39847d1b65..228db7cc72 100644
--- a/paddle/fluid/operators/jit/gen/jitcode.h
+++ b/paddle/fluid/operators/jit/gen/jitcode.h
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
   virtual void genCode() = 0;
 
   size_t getSize() const override { return CodeGenerator::getSize(); }
-  const unsigned char* getCodeInternal() override {
+  const unsigned char* getCodeInternal() const override {
     const Xbyak::uint8* code = CodeGenerator::getCode();
     return code;
   }
diff --git a/paddle/fluid/operators/jit/gen/lstm.cc b/paddle/fluid/operators/jit/gen/lstm.cc
index 08bafb5a81..2c3bc985e9 100644
--- a/paddle/fluid/operators/jit/gen/lstm.cc
+++ b/paddle/fluid/operators/jit/gen/lstm.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/operators/jit/gen/lstm.h"
 #include <stddef.h>  // offsetof
+#include <memory>
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -114,7 +115,7 @@ void LSTMJitCode::genCode() {
   class name##Creator : public JitCodeCreator<lstm_attr_t> {      \
    public:                                                        \
     /* TODO(TJ): enable more */                                   \
-    bool UseMe(const lstm_attr_t& attr) const override {          \
+    bool CanBeUsed(const lstm_attr_t& attr) const override {      \
       return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
     }                                                             \
     size_t CodeSize(const lstm_attr_t& attr) const override {     \
diff --git a/paddle/fluid/operators/jit/gen/matmul.cc b/paddle/fluid/operators/jit/gen/matmul.cc
index ae3858eab2..d9955c8cc6 100644
--- a/paddle/fluid/operators/jit/gen/matmul.cc
+++ b/paddle/fluid/operators/jit/gen/matmul.cc
@@ -14,8 +14,8 @@
 
 #include "paddle/fluid/operators/jit/gen/matmul.h"
 #include <stddef.h>  // offsetof
+#include <memory>
 #include <vector>
-
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
 
 class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
  public:
-  bool UseMe(const matmul_attr_t& attr) const override {
+  bool CanBeUsed(const matmul_attr_t& attr) const override {
     return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
            attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
   }
diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc
index 530d24ee1f..d9e5904add 100644
--- a/paddle/fluid/operators/jit/gen/seqpool.cc
+++ b/paddle/fluid/operators/jit/gen/seqpool.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/gen/seqpool.h"
+#include <memory>
 #include "paddle/fluid/operators/jit/gen/act.h"  // for exp_float_consts ones
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
@@ -57,7 +58,7 @@ void SeqPoolJitCode::genCode() {
 
 class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
  public:
-  bool UseMe(const seq_pool_attr_t& attr) const override {
+  bool CanBeUsed(const seq_pool_attr_t& attr) const override {
     return platform::MayIUse(platform::avx);
   }
   size_t CodeSize(const seq_pool_attr_t& attr) const override {
diff --git a/paddle/fluid/operators/jit/gen/sgd.cc b/paddle/fluid/operators/jit/gen/sgd.cc
index a745a27f95..e65d3500b4 100644
--- a/paddle/fluid/operators/jit/gen/sgd.cc
+++ b/paddle/fluid/operators/jit/gen/sgd.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/operators/jit/gen/sgd.h"
 #include <stddef.h>  // offsetof
+#include <memory>
 #include <vector>
 #include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/fluid/platform/cpu_info.h"
@@ -104,7 +105,7 @@ void SgdJitCode::genCode() {
 
 class SgdCreator : public JitCodeCreator<sgd_attr_t> {
  public:
-  bool UseMe(const sgd_attr_t& attr) const override {
+  bool CanBeUsed(const sgd_attr_t& attr) const override {
     return platform::MayIUse(platform::avx) &&
            attr.grad_width % YMM_FLOAT_BLOCK == 0;
   }
diff --git a/paddle/fluid/operators/jit/gen/vbroadcast.cc b/paddle/fluid/operators/jit/gen/vbroadcast.cc
index 3f9fbdbd82..66a8d75fd4 100644
--- a/paddle/fluid/operators/jit/gen/vbroadcast.cc
+++ b/paddle/fluid/operators/jit/gen/vbroadcast.cc
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
 
 class VBroadcastCreator : public JitCodeCreator<int64_t> {
  public:
-  bool UseMe(const int64_t& w) const override {
+  bool CanBeUsed(const int64_t& w) const override {
     return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
   }
   size_t CodeSize(const int64_t& w) const override {
diff --git a/paddle/fluid/operators/jit/gen_base.cc b/paddle/fluid/operators/jit/gen_base.cc
index f3603875ad..4c49eff49e 100644
--- a/paddle/fluid/operators/jit/gen_base.cc
+++ b/paddle/fluid/operators/jit/gen_base.cc
@@ -31,7 +31,7 @@ namespace paddle {
 namespace operators {
 namespace jit {
 
-// refer do not need useme, it would be the last one.
+// refer does not need CanBeUsed; it would be the last one.
 void GenBase::dumpCode(const unsigned char* code) const {
   if (code) {
     static int counter = 0;
diff --git a/paddle/fluid/operators/jit/gen_base.h b/paddle/fluid/operators/jit/gen_base.h
index a7c7a35a7e..033c603c07 100644
--- a/paddle/fluid/operators/jit/gen_base.h
+++ b/paddle/fluid/operators/jit/gen_base.h
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
   virtual ~GenBase() = default;
   virtual std::string name() const = 0;
   virtual size_t getSize() const = 0;
-  virtual const unsigned char* getCodeInternal() = 0;
+  virtual const unsigned char* getCodeInternal() const = 0;
+  const char* ImplType() const override { return "JitCode"; }
   template <typename Func>
-  Func getCode() {
+  Func getCode() const {
     const unsigned char* code = this->getCodeInternal();
     if (FLAGS_dump_jitcode) {
       this->dumpCode(code);
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
   virtual ~JitCodeCreator() = default;
 
   // condition when this jit code can be used.
-  virtual bool UseMe(const Attr& attr) const = 0;
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
 
   // estimate this code size
   virtual size_t CodeSize(const Attr& attr) const = 0;
diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h
index d85c719c1c..1ac5318d46 100644
--- a/paddle/fluid/operators/jit/helper.h
+++ b/paddle/fluid/operators/jit/helper.h
@@ -16,6 +16,8 @@
 
 #include <iostream>
 #include <string>
+#include <unordered_map>
+#include <utility>  // for std::move
 #include <vector>
 #include "paddle/fluid/operators/jit/gen_base.h"
 #include "paddle/fluid/operators/jit/kernel_base.h"
@@ -27,35 +29,34 @@ namespace paddle {
 namespace operators {
 namespace jit {
 
-template <KernelType KT, typename KernelTuples, typename PlaceType>
+template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
-    std::is_same<typename KernelTuples::data_type, float>::value &&
+    std::is_same<typename KernelTuple::data_type, float>::value &&
         std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuples::func_type>::type
-GetJitCode(const typename KernelTuples::attr_type& attr) {
-  using Func = typename KernelTuples::func_type;
-  using Attr = typename KernelTuples::attr_type;
-  size_t key = JitCodeKey<Attr>(attr);
-  auto& codes = JitCodePool<KT>().Instance();
+    const Kernel*>::type
+GetJitCode(const typename KernelTuple::attr_type& attr) {
+  using Attr = typename KernelTuple::attr_type;
+  int64_t key = JitCodeKey<Attr>(attr);
+  auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
   if (codes.Has(key)) {
-    return codes.AllKernels().at(key)->template getCode<Func>();
+    return codes.AllKernels().at(key).get();
   }
 
   // creator is not related with attr, so can use KernelKey as key
-  KernelKey kkey(KT, PlaceType());
+  KernelKey kkey(KernelTuple::kernel_type, PlaceType());
   // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
-  auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
+  auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
   auto iter = creator_map.find(kkey);
   if (iter != creator_map.end()) {
     auto& creators = iter->second;
     for (auto& cur : creators) {
       auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
-      if (i && i->UseMe(attr)) {
+      if (i && i->CanBeUsed(attr)) {
         auto p = i->CreateJitCode(attr);
         if (p) {
-          auto f = p->template getCode<Func>();
+          auto res = p.get();
           codes.Insert(key, std::move(p));
-          return f;
+          return res;
         }
       }
     }
@@ -63,87 +64,153 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
   return nullptr;
 }
 
-template <KernelType KT, typename KernelTuples, typename PlaceType>
+template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
-    !std::is_same<typename KernelTuples::data_type, float>::value ||
+    !std::is_same<typename KernelTuple::data_type, float>::value ||
         !std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuples::func_type>::type
-GetJitCode(const typename KernelTuples::attr_type& attr) {
+    const Kernel*>::type
+GetJitCode(const typename KernelTuple::attr_type& attr) {
   return nullptr;
 }
 
 // Refer code do not related with attr, which is just for cast
 // Refer is always on CPUPlace
-template <KernelType KT, typename KernelTuples>
-inline typename KernelTuples::func_type GetRefer() {
-  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
-  KernelKey kkey(KT, platform::CPUPlace());
+template <typename KernelTuple>
+inline const Kernel* GetReferKernel() {
+  auto& ref_pool = ReferKernelPool::Instance().AllKernels();
+  KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
   auto ref_iter = ref_pool.find(kkey);
   PADDLE_ENFORCE(ref_iter != ref_pool.end(),
                  "Every Kernel should have reference function.");
   auto& ref_impls = ref_iter->second;
   for (auto& impl : ref_impls) {
-    auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
+    auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
     if (i) {
-      return i->GetFunc();
+      return i;
     }
   }
   return nullptr;
 }
 
-template <KernelType KT, typename KernelTuples,
-          typename PlaceType = platform::CPUPlace>
-typename KernelTuples::func_type Get(
-    const typename KernelTuples::attr_type& attr) {
-  auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
-  if (jitfunc) {
-    return jitfunc;
+template <typename KernelTuple>
+inline typename KernelTuple::func_type GetReferFunc() {
+  auto ker = GetReferKernel<KernelTuple>();
+  auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
+  PADDLE_ENFORCE(p, "The Refer kernel should exsit");
+  return p->GetFunc();
+}
+
+// Return all Kernels that can be used
+template <typename KernelTuple, typename PlaceType>
+std::vector<const Kernel*> GetAllCandidateKernels(
+    const typename KernelTuple::attr_type& attr) {
+  // the search order should be jitcode > more > refer
+  std::vector<const Kernel*> res;
+  auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
+  if (jitker) {
+    res.emplace_back(jitker);
   }
 
-  // pool: (KernelKey(type, place), vector<KernelPtr>)
-  KernelKey kkey(KT, PlaceType());
-  auto& pool = KernelPool().Instance().AllKernels();
+  // more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
+  KernelKey kkey(KernelTuple::kernel_type, PlaceType());
+  auto& pool = KernelPool::Instance().AllKernels();
   auto iter = pool.find(kkey);
   if (iter != pool.end()) {
     auto& impls = iter->second;
     for (auto& impl : impls) {
-      auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        return i->GetFunc();
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
+      if (i && i->CanBeUsed(attr)) {
+        res.emplace_back(i);
       }
     }
   }
 
   // The last implementation should be reference function on CPUPlace.
-  return GetRefer<KT, KernelTuples>();
+  auto ref = GetReferKernel<KernelTuple>();
+  PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
+  res.emplace_back(ref);
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+std::vector<std::pair<std::string, typename KernelTuple::func_type>>
+GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
+  using Func = typename KernelTuple::func_type;
+  auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
+  std::vector<std::pair<std::string, Func>> res;
+  for (auto k : kers) {
+    std::string name = k->ImplType();
+    if (name == "JitCode") {
+      auto i = dynamic_cast<const GenBase*>(k);
+      PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
+    } else {
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
+      PADDLE_ENFORCE(i, "kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->GetFunc()));
+    }
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  std::vector<typename KernelTuple::func_type> res;
+  for (auto& i : funcs) {
+    res.emplace_back(i.second);
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+typename KernelTuple::func_type GetDefaultBestFunc(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
+  PADDLE_ENFORCE_GE(funcs.size(), 1UL);
+  // A runtime benchmark of this attr could be done here to pick the best one.
+  // For now just return the first candidate as the default best one, since
+  // the candidates are searched in order and that order is tuned offline.
+  return funcs[0];
 }
 
-template <KernelType KT, typename KernelTuples, typename PlaceType>
+template <typename KernelTuple, typename PlaceType>
 class KernelFuncs {
  public:
   KernelFuncs() = default;
   static KernelFuncs& Cache() {
-    static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
+    static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
     return g_func_cache;
   }
 
-  bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
-
-  void Insert(int key, typename KernelTuples::func_type func) {
-    funcs_.emplace(key, func);
-  }
-
-  typename KernelTuples::func_type At(int key) {
+  // the exposed interface to use
+  typename KernelTuple::func_type At(
+      const typename KernelTuple::attr_type& attr) {
+    // Maybe this is not good enough, since not all kernels have jitcode.
+    int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
     if (Has(key)) {
       return funcs_.at(key);
     }
-    auto func = Get<KT, KernelTuples, PlaceType>(key);
+    // If this attr is not in the cache, get the default best func and cache it.
+    auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
     Insert(key, func);
     return func;
   }
 
+  typename KernelTuple::func_type operator[](
+      const typename KernelTuple::attr_type& attr) {
+    return At(attr);
+  }
+
+ protected:
+  bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
+  void Insert(int64_t key, typename KernelTuple::func_type func) {
+    funcs_.emplace(key, func);
+  }
+
  private:
-  std::unordered_map<int, typename KernelTuples::func_type> funcs_;
+  std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
   DISABLE_COPY_AND_ASSIGN(KernelFuncs);
 };
 
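// A minimal sketch of how the reworked helper API is consumed. Assumptions:
// the generated "paddle/fluid/operators/jit/kernels.h" header and float data
// on CPUPlace; Softmax is only an example kernel.
#include "paddle/fluid/operators/jit/kernels.h"

void softmax_sketch(const float* x, float* y, int n, int bs) {
  namespace jit = paddle::operators::jit;
  using CPUPlace = paddle::platform::CPUPlace;
  // Preferred call site: a thread-local cache keyed by JitCodeKey(attr),
  // filled by GetDefaultBestFunc on a miss (jitcode > more > refer).
  auto softmax =
      jit::KernelFuncs<jit::SoftmaxTuple<float>, CPUPlace>::Cache().At(n);
  softmax(x, y, n, bs);
  // For tests and benchmarks, every usable implementation can be enumerated:
  // auto funcs =
  //     jit::GetAllCandidateFuncsWithTypes<jit::SoftmaxTuple<float>, CPUPlace>(n);
  // Each entry is a (ImplType(), func) pair in the search order above.
}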
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 96e162a21b..bd34d7dfc7 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -62,26 +62,55 @@ typedef enum {
   kSqrt,
 } SeqPoolType;
 
+// x, y, z, n
 template <typename T>
-struct XYZNTuples {
+struct XYZNTuple {
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const T*, const T*, T*, int);
 };
 
+// a, x, y, n
 template <typename T>
-struct AXYNTuples : public XYZNTuples<T> {};
+struct AXYNTuple : public XYZNTuple<T> {};
 
+// x, y, n
 template <typename T>
-struct XYNTuples {
+struct XYNTuple {
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const T*, T*, int);
 };
 
-// x, return and int
+// x, returned value, n
 template <typename T>
-struct XRNTuples : public XYNTuples<T> {};
+struct XRNTuple : public XYNTuple<T> {};
+
+#define DECLARE_KERNELTUPLE(kernel_tuple, type)        \
+  template <typename T>                                \
+  struct type##Tuple : public kernel_tuple<T> {        \
+    static constexpr KernelType kernel_type = k##type; \
+  }
+
+// Each Tuple should correspond to its KernelType
+DECLARE_KERNELTUPLE(XYZNTuple, VMul);
+DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
+DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
+DECLARE_KERNELTUPLE(XYZNTuple, VSub);
+
+DECLARE_KERNELTUPLE(AXYNTuple, VScal);
+DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
+
+DECLARE_KERNELTUPLE(XYNTuple, VRelu);
+DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
+DECLARE_KERNELTUPLE(XYNTuple, VSquare);
+DECLARE_KERNELTUPLE(XYNTuple, VExp);
+DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
+DECLARE_KERNELTUPLE(XYNTuple, VTanh);
+DECLARE_KERNELTUPLE(XYNTuple, VCopy);
+
+DECLARE_KERNELTUPLE(XRNTuple, HMax);
+DECLARE_KERNELTUPLE(XRNTuple, HSum);
 
 typedef struct {
   void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
 typedef struct lstm_attr_s lstm_attr_t;
 
 template <typename T>
-struct LSTMTuples {
+struct LSTMTuple {
   typedef T data_type;
   typedef lstm_attr_t attr_type;
   typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
 };
 
 template <typename T>
-struct GRUTuples {
+struct GRUTuple {
   typedef T data_type;
   typedef gru_attr_t attr_type;
   typedef void (*func_type)(gru_t*, const gru_attr_t*);
 };
 
+DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
+DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
+
+DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
+DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
+DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
+
+#undef DECLARE_KERNELTUPLE
+
 template <typename T>
-struct VBroadcastTuples {
+struct VBroadcastTuple {
+  static constexpr KernelType kernel_type = kVBroadcast;
   typedef T data_type;
   typedef int64_t attr_type;
   typedef void (*func_type)(const T*, T*, int64_t, int64_t);
@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
 } seq_pool_attr_t;
 
 template <typename T>
-struct SeqPoolTuples {
+struct SeqPoolTuple {
+  static constexpr KernelType kernel_type = kSeqPool;
   typedef T data_type;
   typedef seq_pool_attr_t attr_type;
   typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
 } emb_seq_pool_attr_t;
 
 template <typename T>
-struct EmbSeqPoolTuples {
+struct EmbSeqPoolTuple {
+  static constexpr KernelType kernel_type = kEmbSeqPool;
   typedef T data_type;
   typedef emb_seq_pool_attr_t attr_type;
   typedef void (*func_type)(const T*, const int64_t*, T*,
@@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
 } sgd_attr_t;
 
 template <typename T>
-struct SgdTuples {
+struct SgdTuple {
+  static constexpr KernelType kernel_type = kSgd;
   typedef T data_type;
   typedef sgd_attr_t attr_type;
   typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
@@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
 } matmul_attr_t;
 
 template <typename T>
-struct MatMulTuples {
+struct MatMulTuple {
+  static constexpr KernelType kernel_type = kMatMul;
   typedef T data_type;
   typedef matmul_attr_t attr_type;
   typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
 };
 
 template <typename T>
-struct CRFDecodingTuples {
+struct CRFDecodingTuple {
+  static constexpr KernelType kernel_type = kCRFDecoding;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
 };
 
 template <typename T>
-struct LayerNormTuples {
+struct LayerNormTuple {
+  static constexpr KernelType kernel_type = kLayerNorm;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
@@ -236,7 +281,8 @@ struct LayerNormTuples {
 };
 
 template <typename T>
-struct SoftmaxTuples {
+struct SoftmaxTuple {
+  static constexpr KernelType kernel_type = kSoftmax;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const T*, T*, int, int);
@@ -244,7 +290,8 @@ struct SoftmaxTuples {
 
 // nChw16c = nChw16c .* NC
 template <typename T>
-struct NCHW16CMulNCTuples {
+struct NCHW16CMulNCTuple {
+  static constexpr KernelType kernel_type = kNCHW16CMulNC;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const T*, const T*, T*, int, int);
@@ -255,28 +302,29 @@ class Kernel {
  public:
   Kernel() = default;
   virtual ~Kernel() = default;
+  virtual const char* ImplType() const = 0;
   DISABLE_COPY_AND_ASSIGN(Kernel);
 };
 
-template <typename KernelTuples>
+template <typename KernelTuple>
 class KernelMore : public Kernel {
  public:
-  using T = typename KernelTuples::data_type;
-  using Func = typename KernelTuples::func_type;
-  using Attr = typename KernelTuples::attr_type;
+  using T = typename KernelTuple::data_type;
+  using Func = typename KernelTuple::func_type;
+  using Attr = typename KernelTuple::attr_type;
   virtual Func GetFunc() const { return func; }
-  virtual bool UseMe(const Attr& attr) const = 0;
-  virtual const char* ImplType() const = 0;
+  // Specifies whether this kernel can be used; if so, using it should not fail.
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
 
  protected:
   Func func{nullptr};
 };
 
-template <typename KernelTuples>
-class ReferKernel : public KernelMore<KernelTuples> {
+template <typename KernelTuple>
+class ReferKernel : public KernelMore<KernelTuple> {
  public:
   // Refer code can always be used
-  bool UseMe(const typename KernelTuples::attr_type& attr) const override {
+  bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
     return true;
   }
   const char* ImplType() const override { return "Refer"; }
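// Illustration only: what one DECLARE_KERNELTUPLE line above expands to (the
// macro already generates it):
//
//   DECLARE_KERNELTUPLE(XYZNTuple, VMul);
//     ==> template <typename T>
//         struct VMulTuple : public XYZNTuple<T> {
//           static constexpr KernelType kernel_type = kVMul;
//         };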
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 1c2fddcae7..1ad220b397 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/kernel_key.h"
+#include <xxhash.h>  // XXH64: 13.8 GB/s
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -20,71 +21,46 @@ namespace operators {
 namespace jit {
 
 template <>
-size_t JitCodeKey<int>(const int& d) {
+int64_t JitCodeKey<int>(const int& d) {
   return d;
 }
 
 template <>
-size_t JitCodeKey<int64_t>(const int64_t& d) {
+int64_t JitCodeKey<int64_t>(const int64_t& d) {
   return d;
 }
 
-// TODO(TJ): refine and benchmark JitCodeKey generatation
-constexpr int act_type_shift = 3;  // suppot 2^3 act types
-static inline int act_type_convert(KernelType type) {
-  if (type == kVIdentity) {
-    return 0;
-  } else if (type == kVExp) {
-    return 1;
-  } else if (type == kVRelu) {
-    return 2;
-  } else if (type == kVSigmoid) {
-    return 3;
-  } else if (type == kVTanh) {
-    return 4;
-  }
-  PADDLE_THROW("Unsupported act type %d", type);
-  return 0;
-}
-
 template <>
-size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
-  size_t key = attr.d;
-  int gate_key = act_type_convert(attr.act_gate) << 1;
-  int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
-  int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
-  return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
-         attr.use_peephole;
+int64_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
+  return XXH64(&attr, sizeof(gru_attr_t), 0);
 }
 
 template <>
-size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
-  size_t key = attr.d;
-  return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
-         (act_type_convert(attr.act_cand) << act_type_shift);
+int64_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
+  int keys[5] = {
+      attr.d, static_cast<int>(attr.act_gate), static_cast<int>(attr.act_cand),
+      static_cast<int>(attr.act_cell), static_cast<int>(attr.use_peephole)};
+  return XXH64(keys, sizeof(int) * 5, 0);
 }
 
 template <>
-size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
-  size_t key = attr.w;
-  constexpr int pool_type_shift = 3;
-  return (key << pool_type_shift) + static_cast<int>(attr.type);
+int64_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
+  int keys[2] = {attr.w, static_cast<int>(attr.type)};
+  return XXH64(keys, sizeof(int) * 2, 0);
 }
 
 template <>
-size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
-  size_t key = attr.m;
-  constexpr int shift = 21;
-  return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
+int64_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
+  return XXH64(&attr, sizeof(int) * 3, 0);  // m, n, k
 }
 
 template <>
-size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
+int64_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
   return attr.table_width;
 }
 
 template <>
-size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
+int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
   return attr.grad_width;
 }
 
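Note on the kernel_key.cc change above: the hand-rolled shift-and-add packing is replaced by an XXH64 hash of the attribute fields, returned as int64_t. A minimal, self-contained sketch of the same keying approach (demo_attr_t and DemoKey are illustrative names, not part of the patch; only the XXH64 call with seed 0 mirrors the code above):

    #include <xxhash.h>
    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-in for an attribute struct such as gru_attr_t.
    struct demo_attr_t {
      int d;
      int act_gate;
      int act_cand;
    };

    int64_t DemoKey(const demo_attr_t& attr) {
      // Hash the attribute bytes with seed 0; the 64-bit digest becomes the
      // jit-code cache key, as in the gru_attr_t specialization above.
      return static_cast<int64_t>(XXH64(&attr, sizeof(demo_attr_t), 0));
    }

    int main() {
      demo_attr_t attr{8, 1, 2};
      std::printf("%lld\n", static_cast<long long>(DemoKey(attr)));
      return 0;
    }

Hashing the raw struct relies on its byte layout; the lstm_attr_t specialization above instead packs the relevant fields into an int array before hashing, which sidesteps padding bytes. Either way, unlike the old shift-based scheme, the key no longer risks truncation when d is large.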
diff --git a/paddle/fluid/operators/jit/kernel_key.h b/paddle/fluid/operators/jit/kernel_key.h
index 611a0210d6..b2cf92f23e 100644
--- a/paddle/fluid/operators/jit/kernel_key.h
+++ b/paddle/fluid/operators/jit/kernel_key.h
@@ -46,7 +46,7 @@ struct KernelKey {
 
 // Every JitCode should have a method to get the key from attribution
 template <typename Attr>
-size_t JitCodeKey(const Attr& attr);
+int64_t JitCodeKey(const Attr& attr);
 
 }  // namespace jit
 }  // namespace operators
diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h
index 3e15242af2..04710a54ac 100644
--- a/paddle/fluid/operators/jit/kernel_pool.h
+++ b/paddle/fluid/operators/jit/kernel_pool.h
@@ -17,6 +17,7 @@
 #include <memory>  // for unique_ptr
 #include <string>
 #include <unordered_map>
+#include <utility>  // for move
 #include <vector>
 #include "paddle/fluid/operators/jit/gen_base.h"
 #include "paddle/fluid/operators/jit/kernel_base.h"
@@ -30,7 +31,7 @@ namespace jit {
 template <KernelType KT>
 class JitCodePool {
   typedef std::unique_ptr<GenBase> GenBasePtr;
-  typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
+  typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;
 
  public:
   JitCodePool() = default;
@@ -41,9 +42,9 @@ class JitCodePool {
 
   const JitCodeMap& AllKernels() { return codes_; }
 
-  bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
+  bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
 
-  void Insert(size_t key, GenBasePtr value) {
+  void Insert(int64_t key, GenBasePtr value) {
     codes_.emplace(key, std::move(value));
   }
 
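The kernel_pool.h change above simply rekeys JitCodePool by the int64_t produced by JitCodeKey. A simplified sketch of the same lookup-or-insert flow (FakeGen and DemoCodePool are stand-ins, not the real GenBase/JitCodePool):

    #include <cstdint>
    #include <memory>
    #include <unordered_map>
    #include <utility>

    struct FakeGen { int64_t key; };  // stand-in for a generated kernel blob

    class DemoCodePool {
      using Ptr = std::unique_ptr<FakeGen>;
      std::unordered_map<int64_t, Ptr> codes_;

     public:
      bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
      void Insert(int64_t key, Ptr value) { codes_.emplace(key, std::move(value)); }
      const FakeGen& At(int64_t key) const { return *codes_.at(key); }
    };

    int main() {
      DemoCodePool pool;
      int64_t key = 42;  // would come from JitCodeKey(attr)
      if (!pool.Has(key)) {
        pool.Insert(key, std::unique_ptr<FakeGen>(new FakeGen{42}));
      }
      return pool.At(key).key == 42 ? 0 : 1;
    }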
diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
index 16c91f8246..1254d00189 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
+++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
   }
 }
 
-bool CRFDecodingKernel::UseMe(const int& d) const {
+bool CRFDecodingKernel::CanBeUsed(const int& d) const {
 #ifdef __AVX512F__
   constexpr int block = ZMM_FLOAT_BLOCK;
 #else
diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
index 24179d90dd..49b1a1fea4 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
+++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
@@ -26,11 +26,11 @@ namespace intrinsic {
 void CRFDecoding(const int seq_len, const float* x, const float* w,
                  float* alpha, int* track, int tag_num);
 
-class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
+class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
  public:
   CRFDecodingKernel() { this->func = CRFDecoding; }
-  bool UseMe(
-      const typename CRFDecodingTuples<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename CRFDecodingTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
 
diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
index e9b6e401c6..a4e3246f10 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
   }
 }
 
-bool LayerNormKernel::UseMe(const int& d) const {
+bool LayerNormKernel::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
 }
 
diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
index 89da2940f4..7b9f676050 100644
--- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
+++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
                const float* scale, const float* bias, int height,
                const float epsilon, int right);
 
-class LayerNormKernel : public KernelMore<LayerNormTuples<float>> {
+class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
  public:
   LayerNormKernel() { this->func = LayerNorm; }
-  bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename LayerNormTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
 
diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc
index 0036d1c238..6e709a16d2 100644
--- a/paddle/fluid/operators/jit/more/mix/mix.cc
+++ b/paddle/fluid/operators/jit/more/mix/mix.cc
@@ -23,6 +23,8 @@ namespace jit {
 namespace more {
 namespace mix {
 
+using CPUPlace = platform::CPUPlace;
+
 void VSigmoid(const T* x, T* y, int n) {
   const float min = SIGMOID_THRESHOLD_MIN;
   const float max = SIGMOID_THRESHOLD_MAX;
@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
     y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
     y[i] = static_cast<T>(0) - y[i];
   }
-  auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
+  auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
   compute(y, y, n);
   for (int i = 0; i < n; ++i) {
     y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
 
 void VTanh(const T* x, T* y, int n) {
   const T a = 2, b = -1;
-  auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
-  auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
-  auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
+  auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
   compute_scal(&a, x, y, n);
   compute_sigmoid(y, y, n);
   compute_scal(&a, y, y, n);
@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
 }
 
 void Softmax(const T* x, T* y, int n, int bs) {
-  auto compute_hmax =
-      KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
-  auto compute_hsum =
-      KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
-  auto compute_vscal =
-      KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
+  auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vaddbias =
-      KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
-  auto compute_vexp =
-      KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
+      KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
 
   for (int i = 0; i < bs; ++i) {
     T scalar;
@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
 
 void (*getActFunc(KernelType type, int d))(const T*, T*, int) {  // NOLINT
   if (type == kVSigmoid) {
-    return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
   } else if (type == kVRelu) {
-    return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
   } else if (type == kVTanh) {
-    return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
   } else if (type == kVIdentity) {
-    return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
   }
   PADDLE_THROW("Not support type: %s", type);
   return nullptr;
@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
   const int d = attr->d;
   const int d2 = d * 2;
   const int d3 = d * 3;
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
-  auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
-  auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
+  auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
+  auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
   auto act_gate_d = getActFunc(attr->act_gate, d);
   auto act_gate_d2 = getActFunc(attr->act_gate, d2);
   auto act_gate_d3 = getActFunc(attr->act_gate, d3);
@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
   int d = attr->d;
   int d2 = d * 2;
   int d3 = d * 3;
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
-  auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
+  auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
   auto act_gate_d = getActFunc(attr->act_gate, d);
   auto act_cand_d = getActFunc(attr->act_cand, d);
   auto act_cell_d = getActFunc(attr->act_cell, d);
@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
   int d2 = d * 2;
   auto act_gate = getActFunc(attr->act_gate, d);
   auto act_cand = getActFunc(attr->act_cand, d);
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
   act_gate(gates, gates, d);
   act_cand(gates + d2, gates + d2, d);
   vmul_d(gates, gates + d2, ht, d);
@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
   T* ht = reinterpret_cast<T*>(step->ht);
   const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
   auto act_gate = getActFunc(attr->act_gate, attr->d);
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
   act_gate(gates + attr->d, gates + attr->d, attr->d);
   vmul_d(ht_1, gates + attr->d, ht, attr->d);
 }
@@ -206,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
 }
 
 // TODO(TJ): tuning me
-bool VSigmoidKernel::UseMe(const int& d) const { return true; }
+bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
 
-bool VTanhKernel::UseMe(const int& d) const { return true; }
+bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
 
-bool SoftmaxKernel::UseMe(const int& d) const { return true; }
+bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
 
-bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
 }  // namespace mix
 }  // namespace more
@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
 
 namespace mix = paddle::operators::jit::more::mix;
 
-#define REGISTER_MORE_KERNEL(key, func) \
-  REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
-
-REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
-REGISTER_MORE_KERNEL(kVTanh, VTanh);
-REGISTER_MORE_KERNEL(kSoftmax, Softmax);
-REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
-REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
-REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
-REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
-REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
+#define REGISTER_MORE_KERNEL(func) \
+  REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
+
+REGISTER_MORE_KERNEL(VSigmoid);
+REGISTER_MORE_KERNEL(VTanh);
+REGISTER_MORE_KERNEL(Softmax);
+REGISTER_MORE_KERNEL(LSTMCtHt);
+REGISTER_MORE_KERNEL(LSTMC1H1);
+REGISTER_MORE_KERNEL(GRUH1);
+REGISTER_MORE_KERNEL(GRUHtPart1);
+REGISTER_MORE_KERNEL(GRUHtPart2);
 
 #undef REGISTER_MORE_KERNEL
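In mix.cc above, Softmax is composed entirely from cached primitive kernels fetched with KernelFuncs<...>::Cache().At(n): row max (HMax), bias by the negated max (VAddBias), VExp, row sum (HSum), then VScal by 1/sum, which is the standard numerically stable softmax. A plain scalar sketch of that composition, assuming nothing from Paddle (each loop stands in for one cached kernel):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Scalar sketch of the composition used by mix::Softmax above.
    // x and y are row-major with shape bs x n.
    void SoftmaxSketch(const float* x, float* y, int n, int bs) {
      for (int i = 0; i < bs; ++i, x += n, y += n) {
        float scalar = x[0];                                  // HMax
        for (int j = 1; j < n; ++j) scalar = std::max(scalar, x[j]);
        for (int j = 0; j < n; ++j) y[j] = x[j] - scalar;     // VAddBias(-max)
        for (int j = 0; j < n; ++j) y[j] = std::exp(y[j]);    // VExp
        float sum = 0.f;                                      // HSum
        for (int j = 0; j < n; ++j) sum += y[j];
        for (int j = 0; j < n; ++j) y[j] /= sum;              // VScal(1/sum)
      }
    }

    int main() {
      std::vector<float> x = {1.f, 2.f, 3.f}, y(3);
      SoftmaxSketch(x.data(), y.data(), 3, 1);
      return 0;
    }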
diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h
index d64af19219..994d485909 100644
--- a/paddle/fluid/operators/jit/more/mix/mix.h
+++ b/paddle/fluid/operators/jit/more/mix/mix.h
@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
 
-#define DECLARE_MORE_KERNEL(name, tuples)                            \
-  class name##Kernel : public KernelMore<tuples<T>> {                \
-   public:                                                           \
-    name##Kernel() { this->func = name; }                            \
-    bool UseMe(const typename tuples<T>::attr_type&) const override; \
-    const char* ImplType() const override { return "Mixed"; }        \
+#define DECLARE_MORE_KERNEL(name)                                             \
+  class name##Kernel : public KernelMore<name##Tuple<T>> {                    \
+   public:                                                                    \
+    name##Kernel() { this->func = name; }                                     \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
+    const char* ImplType() const override { return "Mixed"; }                 \
   }
 
 // XYN
-DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
-DECLARE_MORE_KERNEL(VTanh, XYNTuples);
+DECLARE_MORE_KERNEL(VSigmoid);
+DECLARE_MORE_KERNEL(VTanh);
 
 // XRN
-DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples);
+DECLARE_MORE_KERNEL(Softmax);
 
-DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
-DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
+DECLARE_MORE_KERNEL(LSTMCtHt);
+DECLARE_MORE_KERNEL(LSTMC1H1);
 
-DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
-DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
-DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
+DECLARE_MORE_KERNEL(GRUH1);
+DECLARE_MORE_KERNEL(GRUHtPart1);
+DECLARE_MORE_KERNEL(GRUHtPart2);
 
 #undef DECLARE_MORE_KERNEL
 
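DECLARE_MORE_KERNEL above now takes only the kernel name and derives the tuple type by token pasting (name##Tuple<T>), which is why every XXXTuples argument could be dropped. A stripped-down sketch of the same technique with stand-in types (KernelMoreDemo and the local VTanhTuple are placeholders, not the Paddle templates):

    #include <iostream>

    template <typename T> struct VTanhTuple { using attr_type = int; };

    template <typename Tuple>
    struct KernelMoreDemo {
      virtual ~KernelMoreDemo() = default;
      virtual bool CanBeUsed(const typename Tuple::attr_type&) const = 0;
      virtual const char* ImplType() const = 0;
    };

    using T = float;

    #define DECLARE_DEMO_KERNEL(name)                                     \
      class name##Kernel : public KernelMoreDemo<name##Tuple<T>> {        \
       public:                                                            \
        bool CanBeUsed(const name##Tuple<T>::attr_type&) const override { \
          return true;                                                    \
        }                                                                 \
        const char* ImplType() const override { return "Mixed"; }         \
      }

    // Expands to: class VTanhKernel : public KernelMoreDemo<VTanhTuple<float>> {...};
    DECLARE_DEMO_KERNEL(VTanh);
    #undef DECLARE_DEMO_KERNEL

    int main() {
      VTanhKernel k;
      std::cout << k.ImplType() << " CanBeUsed: " << k.CanBeUsed(8) << "\n";
      return 0;
    }

The same naming convention lets REGISTER_MORE_KERNEL(func) derive the registry key as k##func, so kernel name, tuple type, and key stay in lockstep from a single identifier.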
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc
index 4f51353bce..4f600b3814 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.cc
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc
@@ -130,105 +130,106 @@ void ASum<double>(const double* x, double* res, int n) {
 
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
 template <>
-bool VMulKernel<float>::UseMe(const int& d) const {
+bool VMulKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VAddKernel<float>::UseMe(const int& d) const {
+bool VAddKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d > 512;
 }
 
 template <>
-bool VScalKernel<float>::UseMe(const int& d) const {
+bool VScalKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VExpKernel<float>::UseMe(const int& d) const {
+bool VExpKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VSquareKernel<float>::UseMe(const int& d) const {
+bool VSquareKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VCopyKernel<float>::UseMe(const int& d) const {
+bool VCopyKernel<float>::CanBeUsed(const int& d) const {
   return d > 15;
 }
 
 template <>
-bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
+bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
   return d > 127;
 }
 
 template <>
-bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
+bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
   return true;
 }
 
 template <>
-bool VSigmoidKernel<float>::UseMe(const int& d) const {
+bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VTanhKernel<float>::UseMe(const int& d) const {
+bool VTanhKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<double>::CanBeUsed(
+    const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
   return platform::MayIUse(platform::avx);
 }
 
 template <>
-bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SoftmaxKernel<float>::UseMe(const int& d) const {
+bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
   // tuned on avx2
   return platform::MayIUse(platform::avx) && d < 60;
 }
 
-#define AWALYS_USE_ME_WITH_DOUBLE(func)                  \
-  template <>                                            \
-  bool func##Kernel<double>::UseMe(const int& d) const { \
-    return true;                                         \
+#define AWALYS_USE_ME_WITH_DOUBLE(func)                      \
+  template <>                                                \
+  bool func##Kernel<double>::CanBeUsed(const int& d) const { \
+    return true;                                             \
   }
 
 AWALYS_USE_ME_WITH_DOUBLE(VMul);
@@ -250,23 +251,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
 
 namespace mkl = paddle::operators::jit::more::mkl;
 
-#define REGISTER_MKL_KERNEL(key, func)                        \
-  REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
+#define REGISTER_MKL_KERNEL(func)                                 \
+  REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
                           mkl::func##Kernel<double>)
 
-REGISTER_MKL_KERNEL(kMatMul, MatMul);
-REGISTER_MKL_KERNEL(kVMul, VMul);
-REGISTER_MKL_KERNEL(kVAdd, VAdd);
-REGISTER_MKL_KERNEL(kVScal, VScal);
-REGISTER_MKL_KERNEL(kVExp, VExp);
-REGISTER_MKL_KERNEL(kVSquare, VSquare);
-REGISTER_MKL_KERNEL(kVCopy, VCopy);
-REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast);
-REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
-REGISTER_MKL_KERNEL(kVTanh, VTanh);
-REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
-REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
-REGISTER_MKL_KERNEL(kSoftmax, Softmax);
-REGISTER_MKL_KERNEL(kSgd, Sgd);
+REGISTER_MKL_KERNEL(MatMul);
+REGISTER_MKL_KERNEL(VMul);
+REGISTER_MKL_KERNEL(VAdd);
+REGISTER_MKL_KERNEL(VScal);
+REGISTER_MKL_KERNEL(VExp);
+REGISTER_MKL_KERNEL(VSquare);
+REGISTER_MKL_KERNEL(VCopy);
+REGISTER_MKL_KERNEL(VBroadcast);
+REGISTER_MKL_KERNEL(VSigmoid);
+REGISTER_MKL_KERNEL(VTanh);
+REGISTER_MKL_KERNEL(SeqPool);
+REGISTER_MKL_KERNEL(EmbSeqPool);
+REGISTER_MKL_KERNEL(Softmax);
+REGISTER_MKL_KERNEL(Sgd);
 
 #undef REGISTER_MKL_KERNEL
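The CanBeUsed predicates in mkl.cc above gate each MKL implementation by problem size and ISA (e.g. VMul only when AVX-512 is available and d > 512). At lookup time the framework roughly walks the registered "more" kernels, takes the first whose CanBeUsed(attr) passes, and falls back to the reference kernel otherwise. A hedged sketch of that selection loop (the Candidate list and Select helper are invented for illustration, not the real helper API):

    #include <functional>
    #include <string>
    #include <vector>

    using VFunc = std::function<void(const float*, float*, int)>;

    struct Candidate {
      std::string impl;                      // e.g. "MKL", "Intrinsic", "Mixed"
      std::function<bool(int)> can_be_used;  // mirrors CanBeUsed(attr)
      VFunc func;
    };

    // First candidate whose CanBeUsed(attr) passes wins; otherwise use refer.
    VFunc Select(const std::vector<Candidate>& more, const VFunc& refer, int d) {
      for (const auto& c : more) {
        if (c.can_be_used(d)) return c.func;
      }
      return refer;
    }

    int main() {
      VFunc refer = [](const float*, float*, int) { /* scalar fallback */ };
      std::vector<Candidate> more = {
          {"MKL", [](int d) { return d > 512; }, refer},  // size-gated, like VMul above
      };
      Select(more, refer, 8)(nullptr, nullptr, 0);     // small d -> refer
      Select(more, refer, 1024)(nullptr, nullptr, 0);  // large d -> MKL candidate
      return 0;
    }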
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h
index db2d6faed4..f51dca654c 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.h
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.h
@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
-#define DECLARE_MKL_KERNEL(name, tuples)                             \
-  template <typename T>                                              \
-  class name##Kernel : public KernelMore<tuples<T>> {                \
-   public:                                                           \
-    name##Kernel() { this->func = name<T>; }                         \
-    bool UseMe(const typename tuples<T>::attr_type&) const override; \
-    const char* ImplType() const override { return "MKL"; }          \
+#define DECLARE_MKL_KERNEL(name)                                              \
+  template <typename T>                                                       \
+  class name##Kernel : public KernelMore<name##Tuple<T>> {                    \
+   public:                                                                    \
+    name##Kernel() { this->func = name<T>; }                                  \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
+    const char* ImplType() const override { return "MKL"; }                   \
   }
 
 // ABCMNK
-DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
+DECLARE_MKL_KERNEL(MatMul);
 
 // XYZN
-DECLARE_MKL_KERNEL(VMul, XYZNTuples);
-DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
+DECLARE_MKL_KERNEL(VMul);
+DECLARE_MKL_KERNEL(VAdd);
 
 // AXYN
-DECLARE_MKL_KERNEL(VScal, AXYNTuples);
+DECLARE_MKL_KERNEL(VScal);
 
 // XYN
-DECLARE_MKL_KERNEL(VExp, XYNTuples);
-DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
-DECLARE_MKL_KERNEL(VTanh, XYNTuples);
-DECLARE_MKL_KERNEL(VSquare, XYNTuples);
-DECLARE_MKL_KERNEL(VCopy, XYNTuples);
-
-DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
-
-DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
-
-DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
-
-DECLARE_MKL_KERNEL(Sgd, SgdTuples);
-
-DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
+DECLARE_MKL_KERNEL(VExp);
+DECLARE_MKL_KERNEL(VSigmoid);
+DECLARE_MKL_KERNEL(VTanh);
+DECLARE_MKL_KERNEL(VSquare);
+DECLARE_MKL_KERNEL(VCopy);
+
+// others
+DECLARE_MKL_KERNEL(SeqPool);
+DECLARE_MKL_KERNEL(EmbSeqPool);
+DECLARE_MKL_KERNEL(Softmax);
+DECLARE_MKL_KERNEL(Sgd);
+DECLARE_MKL_KERNEL(VBroadcast);
 
 #undef DECLARE_MKL_KERNEL
 
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index c279d1b2ca..0d1c477090 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -17,51 +17,43 @@
 
 namespace refer = paddle::operators::jit::refer;
 
-#define REGISTER_REFER_KERNEL(key, func)                    \
-  REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
+#define REGISTER_REFER_KERNEL(func)                             \
+  REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>, \
                            refer::func##Kernel<double>)
 
-REGISTER_REFER_KERNEL(kVMul, VMul);
-REGISTER_REFER_KERNEL(kVAdd, VAdd);
-REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
-REGISTER_REFER_KERNEL(kVSub, VSub);
-
-REGISTER_REFER_KERNEL(kVScal, VScal);
-REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
-
-REGISTER_REFER_KERNEL(kVRelu, VRelu);
-REGISTER_REFER_KERNEL(kVCopy, VCopy);
-REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
-REGISTER_REFER_KERNEL(kVSquare, VSquare);
-REGISTER_REFER_KERNEL(kVExp, VExp);
-REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
-REGISTER_REFER_KERNEL(kVTanh, VTanh);
-
-REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
-REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
-
-REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
-REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
-REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
-
-REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
-REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
-
-REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
-
-REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
-
-REGISTER_REFER_KERNEL(kMatMul, MatMul);
-
-REGISTER_REFER_KERNEL(kHMax, HMax);
-REGISTER_REFER_KERNEL(kHSum, HSum);
-
-REGISTER_REFER_KERNEL(kSoftmax, Softmax);
-
-REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
-
-REGISTER_REFER_KERNEL(kSgd, Sgd);
-
-REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
+REGISTER_REFER_KERNEL(VMul);
+REGISTER_REFER_KERNEL(VAdd);
+REGISTER_REFER_KERNEL(VAddRelu);
+REGISTER_REFER_KERNEL(VSub);
+
+REGISTER_REFER_KERNEL(VScal);
+REGISTER_REFER_KERNEL(VAddBias);
+
+REGISTER_REFER_KERNEL(VRelu);
+REGISTER_REFER_KERNEL(VCopy);
+REGISTER_REFER_KERNEL(VIdentity);
+REGISTER_REFER_KERNEL(VSquare);
+REGISTER_REFER_KERNEL(VExp);
+REGISTER_REFER_KERNEL(VSigmoid);
+REGISTER_REFER_KERNEL(VTanh);
+
+REGISTER_REFER_KERNEL(LSTMCtHt);
+REGISTER_REFER_KERNEL(LSTMC1H1);
+
+REGISTER_REFER_KERNEL(GRUH1);
+REGISTER_REFER_KERNEL(GRUHtPart1);
+REGISTER_REFER_KERNEL(GRUHtPart2);
+
+REGISTER_REFER_KERNEL(CRFDecoding);
+REGISTER_REFER_KERNEL(LayerNorm);
+REGISTER_REFER_KERNEL(NCHW16CMulNC);
+REGISTER_REFER_KERNEL(SeqPool);
+REGISTER_REFER_KERNEL(MatMul);
+REGISTER_REFER_KERNEL(HMax);
+REGISTER_REFER_KERNEL(HSum);
+REGISTER_REFER_KERNEL(Softmax);
+REGISTER_REFER_KERNEL(EmbSeqPool);
+REGISTER_REFER_KERNEL(Sgd);
+REGISTER_REFER_KERNEL(VBroadcast);
 
 #undef REGISTER_REFER_KERNEL
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index b3b2097828..cac705a484 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
-#define DECLARE_REFER_KERNEL(name, tuples)             \
-  template <typename T>                                \
-  class name##Kernel : public ReferKernel<tuples<T>> { \
-   public:                                             \
-    name##Kernel() { this->func = name<T>; }           \
+#define DECLARE_REFER_KERNEL(name)                          \
+  template <typename T>                                     \
+  class name##Kernel : public ReferKernel<name##Tuple<T>> { \
+   public:                                                  \
+    name##Kernel() { this->func = name<T>; }                \
   }
 
 // const T* x, const T* y, T* z, int n
-DECLARE_REFER_KERNEL(VMul, XYZNTuples);
-DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
-DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
-DECLARE_REFER_KERNEL(VSub, XYZNTuples);
+DECLARE_REFER_KERNEL(VMul);
+DECLARE_REFER_KERNEL(VAdd);
+DECLARE_REFER_KERNEL(VAddRelu);
+DECLARE_REFER_KERNEL(VSub);
 
 // const T* a, const T* x, T* y, int n
-DECLARE_REFER_KERNEL(VScal, AXYNTuples);
-DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
+DECLARE_REFER_KERNEL(VScal);
+DECLARE_REFER_KERNEL(VAddBias);
 
 // const T* x, T* y, int n
-DECLARE_REFER_KERNEL(VRelu, XYNTuples);
-DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
-DECLARE_REFER_KERNEL(VExp, XYNTuples);
-DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
-DECLARE_REFER_KERNEL(VTanh, XYNTuples);
-DECLARE_REFER_KERNEL(VSquare, XYNTuples);
-DECLARE_REFER_KERNEL(VCopy, XYNTuples);
+DECLARE_REFER_KERNEL(VRelu);
+DECLARE_REFER_KERNEL(VIdentity);
+DECLARE_REFER_KERNEL(VExp);
+DECLARE_REFER_KERNEL(VSigmoid);
+DECLARE_REFER_KERNEL(VTanh);
+DECLARE_REFER_KERNEL(VSquare);
+DECLARE_REFER_KERNEL(VCopy);
 
 // lstm_t*, const lstm_attr_t*
-DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
-DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
+DECLARE_REFER_KERNEL(LSTMCtHt);
+DECLARE_REFER_KERNEL(LSTMC1H1);
 
 // gru_t*, const gru_attr_t*
-DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
-DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
-DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
-
-DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
-DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
-
-DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
-
-DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
-
-DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
-
-DECLARE_REFER_KERNEL(HMax, XRNTuples);
-DECLARE_REFER_KERNEL(HSum, XRNTuples);
-
-DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
-
-DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
-
-DECLARE_REFER_KERNEL(Sgd, SgdTuples);
-
-DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
+DECLARE_REFER_KERNEL(GRUH1);
+DECLARE_REFER_KERNEL(GRUHtPart1);
+DECLARE_REFER_KERNEL(GRUHtPart2);
+
+DECLARE_REFER_KERNEL(HMax);
+DECLARE_REFER_KERNEL(HSum);
+
+// others
+DECLARE_REFER_KERNEL(CRFDecoding);
+DECLARE_REFER_KERNEL(LayerNorm);
+DECLARE_REFER_KERNEL(NCHW16CMulNC);
+DECLARE_REFER_KERNEL(SeqPool);
+DECLARE_REFER_KERNEL(MatMul);
+DECLARE_REFER_KERNEL(Softmax);
+DECLARE_REFER_KERNEL(EmbSeqPool);
+DECLARE_REFER_KERNEL(Sgd);
+DECLARE_REFER_KERNEL(VBroadcast);
 
 #undef DECLARE_REFER_KERNEL
 
diff --git a/paddle/fluid/operators/jit/registry.h b/paddle/fluid/operators/jit/registry.h
index cb32c48720..567a903236 100644
--- a/paddle/fluid/operators/jit/registry.h
+++ b/paddle/fluid/operators/jit/registry.h
@@ -17,6 +17,7 @@
 #include <memory>
 #include <tuple>
 #include <type_traits>
+#include <utility>  // for std::move
 #include "paddle/fluid/operators/jit/kernel_base.h"
 #include "paddle/fluid/operators/jit/kernel_pool.h"
 #include "paddle/fluid/platform/place.h"
@@ -49,8 +50,8 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
 
   void operator()(KernelType kt) const {
     KernelKey kkey(kt, PlaceType());
-    Pool().Instance().Insert(kkey,
-                             std::move(make_unique<const KERNEL_IMPL_TYPE>()));
+    Pool::Instance().Insert(kkey,
+                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
     constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
     JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
                               KernelImpls...>
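The registry.h change above swaps Pool().Instance() for Pool::Instance(): the old spelling default-constructed a throwaway Pool object just to reach a static member. A minimal sketch of the difference (DemoPool is illustrative; the real pool's constructor is public, which is why the old form compiled at all):

    #include <cassert>

    class DemoPool {
     public:
      static DemoPool& Instance() {  // Meyers singleton: one shared instance
        static DemoPool pool;
        return pool;
      }
      void Insert() { ++size_; }
      int size() const { return size_; }

     private:
      DemoPool() = default;  // private here, so the temporary-object spelling
      int size_ = 0;         // DemoPool().Instance() would not even compile
    };

    int main() {
      DemoPool::Instance().Insert();   // the corrected call shape
      assert(DemoPool::Instance().size() == 1);
      return 0;
    }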
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index cdec14dc43..6c099a7a06 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <algorithm>
+#include <iostream>
 #include <random>
 #include <string>
 #include <vector>
@@ -64,413 +65,23 @@ std::vector<int> TestSizes() {
 namespace jit = paddle::operators::jit;
 using CPUPlace = paddle::platform::CPUPlace;
 
-template <typename KernelTuples, typename... Args>
-struct TestFuncWithRefer {
-  void operator()(const typename KernelTuples::func_type tgt, Args... args) {
-    LOG(FATAL) << "Should specify this function.";
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::XYZNTuples<T>, std::vector<T>, std::vector<T>,
-                         std::vector<T>> {
-  void operator()(const typename jit::XYZNTuples<T>::func_type tgt,
-                  const std::vector<T>& x, const std::vector<T>& y,
-                  const std::vector<T>& zref) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(zref.size(), x.size());
-    EXPECT_EQ(zref.size(), y.size());
-    const T* x_data = x.data();
-    const T* y_data = y.data();
-    const T* zref_data = zref.data();
-    const int d = zref.size();
-
-    std::vector<T> ztgt(d);
-    T* ztgt_data = ztgt.data();
-    // test normal
-    tgt(x_data, y_data, ztgt_data, d);
-    ExpectEQ<T>(ztgt_data, zref_data, d);
-    // test inplace x
-    std::copy(x.begin(), x.end(), ztgt.begin());
-    tgt(ztgt_data, y_data, ztgt_data, d);
-    ExpectEQ<T>(ztgt_data, zref_data, d);
-    // test inplace y
-    std::copy(y.begin(), y.end(), ztgt.begin());
-    tgt(x_data, ztgt_data, ztgt_data, d);
-    ExpectEQ<T>(ztgt_data, zref_data, d);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::AXYNTuples<T>, T, std::vector<T>,
-                         std::vector<T>> {
-  void operator()(const typename jit::AXYNTuples<T>::func_type tgt, const T a,
-                  const std::vector<T>& x, const std::vector<T>& yref) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(yref.size(), x.size());
-    const T* x_data = x.data();
-    const T* yref_data = yref.data();
-    const int d = yref.size();
-    std::vector<T> ytgt(d);
-    T* ytgt_data = ytgt.data();
-    // test normal
-    tgt(&a, x_data, ytgt_data, d);
-    ExpectEQ<T>(ytgt_data, yref_data, d);
-    // test inplace x
-    std::copy(x.begin(), x.end(), ytgt.begin());
-    tgt(&a, ytgt_data, ytgt_data, d);
-    ExpectEQ<T>(ytgt_data, yref_data, d);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::SoftmaxTuples<T>, std::vector<T>, std::vector<T>,
-                         int, int> {
-  void operator()(const typename jit::SoftmaxTuples<T>::func_type tgt,
-                  const std::vector<T>& x, const std::vector<T>& yref, int n,
-                  int bs) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(yref.size(), x.size());
-    EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
-    const T* x_data = x.data();
-    const T* yref_data = yref.data();
-    std::vector<T> ytgt(n * bs);
-    T* ytgt_data = ytgt.data();
-    // test normal
-    tgt(x_data, ytgt_data, n, bs);
-    ExpectEQ<T>(ytgt_data, yref_data, n * bs);
-    // test inplace x
-    std::copy(x.begin(), x.end(), ytgt.begin());
-    tgt(ytgt_data, ytgt_data, n, bs);
-    ExpectEQ<T>(ytgt_data, yref_data, n * bs);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
-  void operator()(const typename jit::XRNTuples<T>::func_type tgt,
-                  const std::vector<T>& x, const T ref_res) {
-    EXPECT_TRUE(tgt != nullptr);
-    T tgt_res;
-    tgt(x.data(), &tgt_res, x.size());
-    ExpectEQ<T>(&tgt_res, &ref_res, 1);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::VBroadcastTuples<T>, std::vector<T>,
-                         std::vector<T>, int64_t,
-                         typename jit::VBroadcastTuples<T>::attr_type> {
-  void operator()(const typename jit::VBroadcastTuples<T>::func_type tgt,
-                  const std::vector<T>& x, const std::vector<T>& yref,
-                  int64_t h,
-                  const typename jit::VBroadcastTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(x.size(), static_cast<size_t>(attr));
-    EXPECT_EQ(yref.size(), x.size() * h);
-    std::vector<T> y(yref.size());
-    const T* x_data = x.data();
-    const T* yref_data = yref.data();
-    T* y_data = y.data();
-    tgt(x_data, y_data, h, attr);
-    ExpectEQ<T>(y_data, yref_data, yref.size());
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
-  void operator()(const typename jit::XYNTuples<T>::func_type tgt,
-                  const std::vector<T>& x, const std::vector<T>& yref) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(yref.size(), x.size());
-    const T* x_data = x.data();
-    const T* yref_data = yref.data();
-    const int d = yref.size();
-    std::vector<T> ytgt(d);
-    T* ytgt_data = ytgt.data();
-    // test normal
-    tgt(x_data, ytgt_data, d);
-    ExpectEQ<T>(ytgt_data, yref_data, d);
-    // test inplace x
-    std::copy(x.begin(), x.end(), ytgt.begin());
-    tgt(ytgt_data, ytgt_data, d);
-    ExpectEQ<T>(ytgt_data, yref_data, d);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
-                         std::vector<T>, std::vector<T>, std::vector<T>,
-                         typename jit::LSTMTuples<T>::attr_type> {
-  void operator()(const typename jit::LSTMTuples<T>::func_type tgt,
-                  const std::vector<T>& xsrc, const std::vector<T>& wp,
-                  const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
-                  const std::vector<T>& ht_ref,
-                  const typename jit::LSTMTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(ct_ref.size(), ht_ref.size());
-    EXPECT_EQ(ct_1.size(), ht_ref.size());
-    EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
-    EXPECT_EQ(wp.size(), 3 * ht_ref.size());
-
-    // x could be changed after compute, so copy to save src
-    int d = ht_ref.size();
-    std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
-    std::vector<T> checked(2 * d);
-    std::copy(xsrc.begin(), xsrc.end(), x.begin());
-
-    const T* ct_1_data = ct_1.data();
-    const T* wp_data = wp.data();
-    const T* ct_ref_data = ct_ref.data();
-    const T* ht_ref_data = ht_ref.data();
-    T* x_data = x.data();
-    T* ct_data = ct.data();
-    T* ht_data = ht.data();
-    T* checked_data = checked.data();
-
-    jit::lstm_t step;
-    step.gates = x_data;
-    step.ct_1 = ct_1_data;
-    step.ct = ct_data;
-    step.ht = ht_data;
-    if (attr.use_peephole) {
-      step.wp = wp_data;
-      step.checked = checked_data;
-    }
-
-    tgt(&step, &attr);
-    ExpectEQ<T>(ct_data, ct_ref_data, d);
-    ExpectEQ<T>(ht_data, ht_ref_data, d);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
-                         std::vector<T>,
-                         typename jit::GRUTuples<T>::attr_type> {
-  void operator()(const typename jit::GRUTuples<T>::func_type tgt,
-                  const std::vector<T>& xsrc, const std::vector<T>& ht_1,
-                  const std::vector<T>& ht_ref,
-                  const typename jit::GRUTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(ht_1.size(), ht_ref.size());
-    EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
-
-    // x could be changed after compute, so copy to save src
-    int d = ht_ref.size();
-    std::vector<T> x(xsrc.size()), ht(ht_ref.size());
-    std::copy(xsrc.begin(), xsrc.end(), x.begin());
-    const T* ht_1_data = ht_1.data();
-    const T* ht_ref_data = ht_ref.data();
-    T* x_data = x.data();
-    T* ht_data = ht.data();
-    jit::gru_t step;
-    step.gates = x_data;
-    step.ht_1 = ht_1_data;
-    step.ht = ht_data;
-    tgt(&step, &attr);
-    ExpectEQ<T>(ht_data, ht_ref_data, d);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, std::vector<T>,
-                         typename jit::SeqPoolTuples<T>::attr_type> {
-  void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
-                  const std::vector<T>& x, const std::vector<T>& yref,
-                  const typename jit::SeqPoolTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(x.size() % yref.size(), static_cast<size_t>(0));
-    int w = yref.size();
-    std::vector<T> y(w);
-    const T* x_data = x.data();
-    const T* yref_data = yref.data();
-    T* y_data = y.data();
-    tgt(x_data, y_data, &attr);
-    ExpectEQ<T>(y_data, yref_data, w);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::EmbSeqPoolTuples<T>, std::vector<T>,
-                         std::vector<int64_t>, std::vector<T>,
-                         typename jit::EmbSeqPoolTuples<T>::attr_type> {
-  void operator()(const typename jit::EmbSeqPoolTuples<T>::func_type tgt,
-                  const std::vector<T>& table, const std::vector<int64_t>& idx,
-                  const std::vector<T>& oref,
-                  const typename jit::EmbSeqPoolTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(table.size(),
-              static_cast<size_t>(attr.table_height * attr.table_width));
-    EXPECT_EQ(idx.size(),
-              static_cast<size_t>(attr.index_height * attr.index_width));
-    EXPECT_EQ(oref.size(),
-              static_cast<size_t>(attr.table_width * attr.index_width));
-    const T* table_data = table.data();
-    const int64_t* idx_data = idx.data();
-    const T* oref_data = oref.data();
-    int o_w = oref.size();
-    std::vector<T> out(o_w);
-    T* o_data = out.data();
-    tgt(table_data, idx_data, o_data, &attr);
-    ExpectEQ<T>(o_data, oref_data, o_w);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::SgdTuples<T>, T, std::vector<T>, std::vector<T>,
-                         std::vector<int64_t>, std::vector<T>,
-                         typename jit::SgdTuples<T>::attr_type> {
-  void operator()(const typename jit::SgdTuples<T>::func_type tgt, const T lr,
-                  const std::vector<T>& param, const std::vector<T>& grad,
-                  const std::vector<int64_t>& rows, const std::vector<T>& oref,
-                  const typename jit::SgdTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(param.size(),
-              static_cast<size_t>(attr.param_height * attr.param_width));
-    EXPECT_EQ(grad.size(),
-              static_cast<size_t>(attr.grad_height * attr.grad_width));
-    EXPECT_EQ(rows.size(), static_cast<size_t>(attr.selected_rows_size));
-    EXPECT_EQ(param.size(), oref.size());
-    const T* param_data = param.data();
-    const T* grad_data = grad.data();
-    const int64_t* rows_data = rows.data();
-    const T* oref_data = oref.data();
-
-    std::vector<T> out(oref.size());
-    T* o_data = out.data();
-    tgt(&lr, param_data, grad_data, rows_data, o_data, &attr);
-    // only the selected rows should be equal
-    for (size_t i = 0; i < rows.size(); ++i) {
-      ExpectEQ<T>(o_data + rows[i] * attr.grad_width,
-                  oref_data + rows[i] * attr.grad_width, attr.grad_width);
-    }
-
-    // inplace
-    std::copy(param.begin(), param.end(), out.begin());
-    tgt(&lr, o_data, grad_data, rows_data, o_data, &attr);
-    for (size_t i = 0; i < rows.size(); ++i) {
-      ExpectEQ<T>(o_data + rows[i] * attr.grad_width,
-                  oref_data + rows[i] * attr.grad_width, attr.grad_width);
-    }
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
-                         std::vector<T>,
-                         typename jit::MatMulTuples<T>::attr_type> {
-  void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
-                  const std::vector<T>& a, const std::vector<T>& b,
-                  const std::vector<T>& cref,
-                  const typename jit::MatMulTuples<T>::attr_type& attr) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(a.size(), static_cast<size_t>(attr.m * attr.k));
-    EXPECT_EQ(b.size(), static_cast<size_t>(attr.k * attr.n));
-    EXPECT_EQ(cref.size(), static_cast<size_t>(attr.m * attr.n));
-    std::vector<T> c(cref.size());
-    const T* a_data = a.data();
-    const T* b_data = b.data();
-    const T* cref_data = cref.data();
-    T* c_data = c.data();
-    tgt(a_data, b_data, c_data, &attr);
-    ExpectEQ<T>(c_data, cref_data, attr.m * attr.n);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::LayerNormTuples<T>, std::vector<T>,
-                         std::vector<T>, std::vector<T>, std::vector<T>,
-                         std::vector<T>, std::vector<T>, int, float, int> {
-  void operator()(const typename jit::LayerNormTuples<T>::func_type tgt,
-                  std::vector<T>& x, std::vector<T>& outref,  // NOLINT
-                  std::vector<T>& mean, std::vector<T>& var,  // NOLINT
-                  const std::vector<T>& scale, const std::vector<T>& bias,
-                  int left, const float epsilon, int right) {
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(x.size(), static_cast<size_t>(left * right));
-    EXPECT_EQ(outref.size(), static_cast<size_t>(left * right));
-    EXPECT_EQ(mean.size(), static_cast<size_t>(left));
-    EXPECT_EQ(var.size(), static_cast<size_t>(left));
-    EXPECT_EQ(scale.size(), static_cast<size_t>(right));
-    EXPECT_EQ(bias.size(), static_cast<size_t>(right));
-    std::vector<T> outtgt(outref.size());
-    const T* scale_data = scale.data();
-    const T* bias_data = bias.data();
-    T* x_data = x.data();
-    T* mean_data = mean.data();
-    T* var_data = var.data();
-    T* outref_data = outref.data();
-    T* outtgt_data = outtgt.data();
-
-    tgt(x_data, outtgt_data, mean_data, var_data, scale_data, bias_data, left,
-        epsilon, right);
-    ExpectEQ<T>(outtgt_data, outref_data, left * right);
-  }
-};
-
-template <typename T>
-struct TestFuncWithRefer<jit::CRFDecodingTuples<T>, int, std::vector<T>,
-                         std::vector<T>, std::vector<T>, std::vector<int>,
-                         int> {
-  void operator()(const typename jit::CRFDecodingTuples<T>::func_type tgt,
-                  const int seq_len, const std::vector<T>& x,
-                  const std::vector<T>& w, std::vector<T>& alpharef,  // NOLINT
-                  std::vector<int>& trackref, int tag_num) {          // NOLINT
-    constexpr int state_trans_base_idx = 2;
-    EXPECT_TRUE(tgt != nullptr);
-    EXPECT_EQ(x.size(), static_cast<size_t>(seq_len * tag_num));
-    EXPECT_EQ(w.size(),
-              static_cast<size_t>((tag_num + state_trans_base_idx) * tag_num));
-    EXPECT_EQ(alpharef.size(), static_cast<size_t>(seq_len * tag_num));
-    EXPECT_EQ(trackref.size(), static_cast<size_t>(seq_len * tag_num));
-    std::vector<T> alphatgt(alpharef.size());
-    std::vector<int> tracktgt(trackref.size());
-
-    memcpy(trackref.data(), tracktgt.data(), tag_num * sizeof(int));
-    tgt(seq_len, (const T*)x.data(), (const T*)w.data(), alphatgt.data(),
-        tracktgt.data(), tag_num);
-    ExpectEQ<T>(alpharef.data(), alphatgt.data(), seq_len * tag_num);
-    ExpectEQ<int>(trackref.data(), tracktgt.data(), seq_len * tag_num);
-  }
-};
-
-template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
+template <typename KernelTuple, typename PlaceType, typename Tester,
           typename... Args>
-void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
-  TestFuncWithRefer<KernelTuples, Args...> test;
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
-  if (jitcode) {
-    VLOG(10) << "Test Jitcode Kernel ";
-    test(jitcode, args...);
+void TestAllImpls(const typename KernelTuple::attr_type& attr,
+                  const Tester& verifier, const Args&... args) {
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    VLOG(10) << "Test Kernel " << f.first;
+    verifier(f.second, args...);
   }
-  // test all impls in more
-  jit::KernelKey kkey(KT, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        VLOG(10) << "Test More Kernel : " << i->ImplType();
-        test(more, args...);
-      }
-    }
-  }
-  // test result from Get function
-  // VLOG(10) << "Test Get function ";
-  auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
-  test(tgt, args...);
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelXYZNTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelXYZN() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     std::vector<T> x(d), y(d), zref(d);
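The block deleted above was a table of TestFuncWithRefer specializations, one per tuple type. Each test now builds a local verifier lambda instead and hands it to the new TestAllImpls, which applies the same verifier to every (implementation name, function) pair returned by GetAllCandidateFuncsWithTypes. A generic sketch of that shape, independent of Paddle (all names below are illustrative):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    using AddFunc = int (*)(int, int);

    // Stand-in for GetAllCandidateFuncsWithTypes: every implementation that
    // claims to handle the attribute, paired with its name.
    std::vector<std::pair<std::string, AddFunc>> AllCandidates() {
      return {{"refer", [](int a, int b) { return a + b; }},
              {"fancy", [](int a, int b) { return b + a; }}};
    }

    template <typename Tester, typename... Args>
    void TestAllImplsSketch(const Tester& verifier, const Args&... args) {
      for (const auto& f : AllCandidates()) {
        std::cout << "Test Kernel " << f.first << "\n";
        verifier(f.second, args...);  // same verifier applied to every impl
      }
    }

    int main() {
      auto verifier = [](AddFunc tgt, int a, int b, int ref) {
        if (tgt(a, b) != ref) std::cout << "mismatch\n";
      };
      TestAllImplsSketch(verifier, 2, 3, 5);
      return 0;
    }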
@@ -494,16 +105,42 @@ void TestKernelXYZNTuples() {
     ExpectEQ<T>(xinp_data, zref_data, d);
     ExpectEQ<T>(yinp_data, zref_data, d);
 
-    TestAllImpls<KT, jit::XYZNTuples<T>, PlaceType, std::vector<T>,
-                 std::vector<T>, std::vector<T>>(d, x, y, zref);
+    auto verifier = [](const typename KernelTuple::func_type tgt,
+                       const std::vector<T>& x, const std::vector<T>& y,
+                       const std::vector<T>& zref) {
+      EXPECT_TRUE(tgt != nullptr);
+      EXPECT_EQ(zref.size(), x.size());
+      EXPECT_EQ(zref.size(), y.size());
+      const T* x_data = x.data();
+      const T* y_data = y.data();
+      const T* zref_data = zref.data();
+      const int d = zref.size();
+
+      std::vector<T> ztgt(d);
+      T* ztgt_data = ztgt.data();
+      // test normal
+      tgt(x_data, y_data, ztgt_data, d);
+      ExpectEQ<T>(ztgt_data, zref_data, d);
+      // test inplace x
+      std::copy(x.begin(), x.end(), ztgt.begin());
+      tgt(ztgt_data, y_data, ztgt_data, d);
+      ExpectEQ<T>(ztgt_data, zref_data, d);
+      // test inplace y
+      std::copy(y.begin(), y.end(), ztgt.begin());
+      tgt(x_data, ztgt_data, ztgt_data, d);
+      ExpectEQ<T>(ztgt_data, zref_data, d);
+    };
+
+    TestAllImpls<KernelTuple, PlaceType>(d, verifier, x, y, zref);
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelAXYNTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelAXYN() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     const T a = static_cast<T>(3);
@@ -520,34 +157,33 @@ void TestKernelAXYNTuples() {
     ref(&a, xinp_data, xinp_data, d);
     ExpectEQ<T>(xinp_data, yref_data, d);
 
-    TestAllImpls<KT, jit::AXYNTuples<T>, PlaceType, T, std::vector<T>,
-                 std::vector<T>>(d, a, x, yref);
-  }
-}
-
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelXRNTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  auto last_acc = FLAGS_acc;
-  FLAGS_acc = 1e-4;
-  for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
-    EXPECT_TRUE(ref != nullptr);
-    std::vector<T> x(d);
-    RandomVec<T>(d, x.data());
-    T ref_res;
-    ref(x.data(), &ref_res, d);
-    TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
-                                                                      ref_res);
+    auto verifier = [](const typename KernelTuple::func_type tgt, const T a,
+                       const std::vector<T>& x, const std::vector<T>& yref) {
+      EXPECT_TRUE(tgt != nullptr);
+      EXPECT_EQ(yref.size(), x.size());
+      const T* x_data = x.data();
+      const T* yref_data = yref.data();
+      const int d = yref.size();
+      std::vector<T> ytgt(d);
+      T* ytgt_data = ytgt.data();
+      // test normal
+      tgt(&a, x_data, ytgt_data, d);
+      ExpectEQ<T>(ytgt_data, yref_data, d);
+      // test inplace x
+      std::copy(x.begin(), x.end(), ytgt.begin());
+      tgt(&a, ytgt_data, ytgt_data, d);
+      ExpectEQ<T>(ytgt_data, yref_data, d);
+    };
+    TestAllImpls<KernelTuple, PlaceType>(d, verifier, a, x, yref);
   }
-  FLAGS_acc = last_acc;
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelXYNTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelXYN() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int d : TestSizes()) {
-    auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
+    auto ref = jit::GetReferFunc<KernelTuple>();
     EXPECT_TRUE(ref != nullptr);
 
     std::vector<T> x(d), yref(d);
@@ -562,15 +198,57 @@ void TestKernelXYNTuples() {
     ref(x_data, yref_data, d);
     ref(xinp_data, xinp_data, d);
     ExpectEQ<T>(xinp_data, yref_data, d);
+    auto verifier = [](const typename KernelTuple::func_type tgt,
+                       const std::vector<T>& x, const std::vector<T>& yref) {
+      EXPECT_TRUE(tgt != nullptr);
+      EXPECT_EQ(yref.size(), x.size());
+      const T* x_data = x.data();
+      const T* yref_data = yref.data();
+      const int d = yref.size();
+      std::vector<T> ytgt(d);
+      T* ytgt_data = ytgt.data();
+      // test normal
+      tgt(x_data, ytgt_data, d);
+      ExpectEQ<T>(ytgt_data, yref_data, d);
+      // test inplace x
+      std::copy(x.begin(), x.end(), ytgt.begin());
+      tgt(ytgt_data, ytgt_data, d);
+      ExpectEQ<T>(ytgt_data, yref_data, d);
+    };
+    TestAllImpls<KernelTuple, PlaceType>(d, verifier, x, yref);
+  }
+}
 
-    TestAllImpls<KT, jit::XYNTuples<T>, PlaceType, std::vector<T>,
-                 std::vector<T>>(d, x, yref);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelXRN() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  auto last_acc = FLAGS_acc;
+  FLAGS_acc = 1e-4;
+  for (int d : TestSizes()) {
+    auto ref = jit::GetReferFunc<KernelTuple>();
+    EXPECT_TRUE(ref != nullptr);
+    std::vector<T> x(d);
+    RandomVec<T>(d, x.data());
+    T ref_res;
+    ref(x.data(), &ref_res, d);
+
+    auto verifier = [](const typename KernelTuple::func_type tgt,
+                       const std::vector<T>& x, const T ref_res) {
+      EXPECT_TRUE(tgt != nullptr);
+      T tgt_res;
+      tgt(x.data(), &tgt_res, x.size());
+      ExpectEQ<T>(&tgt_res, &ref_res, 1);
+    };
+    TestAllImpls<KernelTuple, PlaceType>(d, verifier, x, ref_res);
   }
+  FLAGS_acc = last_acc;
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelLSTMTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelLSTM() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
   auto test_sizes = TestSizes();
   test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
@@ -582,7 +260,7 @@ void TestKernelLSTMTuples() {
             const jit::lstm_attr_t attr(
                 d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
                 jit::to_kerneltype(act_cell), use_peephole);
-            auto ref = jit::GetRefer<KT, jit::LSTMTuples<T>>();
+            auto ref = jit::GetReferFunc<KernelTuple>();
             EXPECT_TRUE(ref != nullptr);
             std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
             std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
@@ -609,10 +287,51 @@ void TestKernelLSTMTuples() {
             }
             ref(&step, &attr);
             VLOG(10) << attr;
-            TestAllImpls<KT, jit::LSTMTuples<T>, PlaceType, std::vector<T>,
-                         std::vector<T>, std::vector<T>, std::vector<T>,
-                         std::vector<T>>(attr, xsrc, wp, ct_1, ct_ref, ht_ref,
-                                         attr);
+
+            auto verifier = [](
+                const typename KernelTuple::func_type tgt,
+                const std::vector<T>& xsrc, const std::vector<T>& wp,
+                const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
+                const std::vector<T>& ht_ref,
+                const typename KernelTuple::attr_type& attr) {
+              EXPECT_TRUE(tgt != nullptr);
+              EXPECT_EQ(ct_ref.size(), ht_ref.size());
+              EXPECT_EQ(ct_1.size(), ht_ref.size());
+              EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
+              EXPECT_EQ(wp.size(), 3 * ht_ref.size());
+
+              // x could be changed after compute, so copy to save src
+              int d = ht_ref.size();
+              std::vector<T> x(xsrc.size()), ct(ct_ref.size()),
+                  ht(ht_ref.size());
+              std::vector<T> checked(2 * d);
+              std::copy(xsrc.begin(), xsrc.end(), x.begin());
+
+              const T* ct_1_data = ct_1.data();
+              const T* wp_data = wp.data();
+              const T* ct_ref_data = ct_ref.data();
+              const T* ht_ref_data = ht_ref.data();
+              T* x_data = x.data();
+              T* ct_data = ct.data();
+              T* ht_data = ht.data();
+              T* checked_data = checked.data();
+
+              jit::lstm_t step;
+              step.gates = x_data;
+              step.ct_1 = ct_1_data;
+              step.ct = ct_data;
+              step.ht = ht_data;
+              if (attr.use_peephole) {
+                step.wp = wp_data;
+                step.checked = checked_data;
+              }
+
+              tgt(&step, &attr);
+              ExpectEQ<T>(ct_data, ct_ref_data, d);
+              ExpectEQ<T>(ht_data, ht_ref_data, d);
+            };
+            TestAllImpls<KernelTuple, PlaceType>(attr, verifier, xsrc, wp, ct_1,
+                                                 ct_ref, ht_ref, attr);
           }
         }
       }
@@ -620,9 +339,10 @@ void TestKernelLSTMTuples() {
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelGRUTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelGRU() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
   auto test_sizes = TestSizes();
   test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
@@ -631,7 +351,7 @@ void TestKernelGRUTuples() {
       for (auto& act_cand : all_acts) {
         const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
                                    jit::to_kerneltype(act_cand));
-        auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
         RandomVec<T>(3 * d, xsrc.data());
@@ -648,17 +368,218 @@ void TestKernelGRUTuples() {
         step.ht = ht_ref_data;
         ref(&step, &attr);
         VLOG(10) << attr;
-        TestAllImpls<KT, jit::GRUTuples<T>, PlaceType, std::vector<T>,
-                     std::vector<T>, std::vector<T>>(attr, xsrc, ht_1, ht_ref,
-                                                     attr);
+        auto verifier = [](const typename KernelTuple::func_type tgt,
+                           const std::vector<T>& xsrc,
+                           const std::vector<T>& ht_1,
+                           const std::vector<T>& ht_ref,
+                           const typename KernelTuple::attr_type& attr) {
+          EXPECT_TRUE(tgt != nullptr);
+          EXPECT_EQ(ht_1.size(), ht_ref.size());
+          EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
+
+          // x may be modified by the kernel, so copy it to preserve the source
+          int d = ht_ref.size();
+          std::vector<T> x(xsrc.size()), ht(ht_ref.size());
+          std::copy(xsrc.begin(), xsrc.end(), x.begin());
+          const T* ht_1_data = ht_1.data();
+          const T* ht_ref_data = ht_ref.data();
+          T* x_data = x.data();
+          T* ht_data = ht.data();
+          jit::gru_t step;
+          step.gates = x_data;
+          step.ht_1 = ht_1_data;
+          step.ht = ht_data;
+          tgt(&step, &attr);
+          ExpectEQ<T>(ht_data, ht_ref_data, d);
+        };
+        TestAllImpls<KernelTuple, PlaceType>(attr, verifier, xsrc, ht_1, ht_ref,
+                                             attr);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelSeqPoolTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelNCHW16CMulNC() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  const int n = 3, c = 16 * 4, h = 10, w = 10;
+  auto ref = jit::GetReferFunc<KernelTuple>();
+  EXPECT_TRUE(ref != nullptr);
+  int sz = n * c * h * w;
+  std::vector<T> x(sz), y(n * c), zref(sz);
+  std::vector<T> ztgt(sz), zjit(sz);
+  RandomVec<T>(sz, x.data());
+  RandomVec<T>(n * c, y.data());
+
+  const T* x_data = x.data();
+  const T* y_data = y.data();
+  T* zref_data = zref.data();
+  T* ztgt_data = ztgt.data();
+  T* zjit_data = zjit.data();
+  constexpr int simd_width = ZMM_FLOAT_BLOCK;
+  int C = c / simd_width;
+  auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(0);
+  auto funcs = jit::GetAllCandidateFuncs<KernelTuple, PlaceType>(0);
+  EXPECT_GT(funcs.size(), 0UL);
+  auto jitcode = funcs[0];
+  EXPECT_TRUE(tgt != nullptr);
+
+  if (std::is_same<T, float>::value &&
+      paddle::platform::MayIUse(paddle::platform::avx512f)) {
+    EXPECT_TRUE(jitcode != nullptr);
+  }
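+  // x/z use the NCHW16C layout: each (ni, ci) pair addresses one channel block
+  // of h * w * simd_width floats, while y holds simd_width floats per block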
+  for (int ni = 0; ni < n; ni++) {
+    for (int ci = 0; ci < C; ci++) {
+      auto ptr_x =
+          x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+      auto ptr_zref =
+          zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      auto ptr_ztgt =
+          ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+
+      ref(ptr_x, ptr_y, ptr_zref, h, w);
+      tgt(ptr_x, ptr_y, ptr_ztgt, h, w);
+
+      if (jitcode) {
+        auto ptr_zjit =
+            zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+        jitcode(ptr_x, ptr_y, ptr_zjit, h, w);
+      }
+    }
+  }
+  ExpectEQ<T>(ztgt_data, zref_data, sz);
+  if (jitcode) {
+    ExpectEQ<T>(zjit_data, zref_data, sz);
+  }
+}
+
+template <typename KernelTuple, typename PlaceType>
+void TestKernelLayerNorm() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  const T epsilon = 9.99999975e-06;
+  for (int n : {1, 2, 10}) {
+    for (int x_dim_0 : {1, 9, 17, 50}) {
+      int left = n * x_dim_0;
+      for (int x_dim_1 : TestSizes()) {
+        int right = x_dim_1;
+        auto ref = jit::GetReferFunc<KernelTuple>();
+        EXPECT_TRUE(ref != nullptr);
+        int sz = left * right;
+        std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
+            outref(sz);
+        RandomVec<T>(sz, x.data());
+        RandomVec<T>(left, mean.data());
+        RandomVec<T>(left, var.data());
+        RandomVec<T>(right, scale.data());
+        RandomVec<T>(right, bias.data());
+
+        const T* scale_data = scale.data();
+        const T* bias_data = bias.data();
+        T* x_data = x.data();
+        T* mean_data = mean.data();
+        T* var_data = var.data();
+        T* outref_data = outref.data();
+
+        ref(x_data, outref_data, mean_data, var_data, scale_data, bias_data,
+            left, epsilon, right);
+
+        auto verifier = [](
+            const typename KernelTuple::func_type tgt, const std::vector<T>& x_,
+            const std::vector<T>& outref_, const std::vector<T>& mean_,
+            const std::vector<T>& var_, const std::vector<T>& scale,
+            const std::vector<T>& bias, const int& left, const float& epsilon,
+            const typename KernelTuple::attr_type& right) {
+          EXPECT_TRUE(tgt != nullptr);
+          std::vector<T> outtgt(outref_.size());
+          std::vector<T> x(x_.size());
+          std::vector<T> mean(mean_.size());
+          std::vector<T> var(var_.size());
+          std::vector<T> outref(outref_.size());
+          std::copy(x_.begin(), x_.end(), x.begin());
+          std::copy(mean_.begin(), mean_.end(), mean.begin());
+          std::copy(var_.begin(), var_.end(), var.begin());
+          std::copy(outref_.begin(), outref_.end(), outref.begin());
+
+          EXPECT_EQ(x.size(), static_cast<size_t>(left * right));
+          EXPECT_EQ(outref.size(), static_cast<size_t>(left * right));
+          EXPECT_EQ(mean.size(), static_cast<size_t>(left));
+          EXPECT_EQ(var.size(), static_cast<size_t>(left));
+          EXPECT_EQ(scale.size(), static_cast<size_t>(right));
+          EXPECT_EQ(bias.size(), static_cast<size_t>(right));
+
+          const T* scale_data = scale.data();
+          const T* bias_data = bias.data();
+          T* x_data = x.data();
+          T* mean_data = mean.data();
+          T* var_data = var.data();
+          T* outref_data = outref.data();
+          T* outtgt_data = outtgt.data();
+          tgt(x_data, outtgt_data, mean_data, var_data, scale_data, bias_data,
+              left, epsilon, right);
+          ExpectEQ<T>(outtgt_data, outref_data, left * right);
+        };
+        TestAllImpls<KernelTuple, PlaceType>(right, verifier, x, outref, mean,
+                                             var, scale, bias, left, epsilon,
+                                             right);
+      }
+    }
+  }
+}
+
+template <typename KernelTuple, typename PlaceType>
+void TestKernelCRFDecoding() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  constexpr int state_trans_base_idx = 2;
+  auto test_sizes = TestSizes();
+  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
+  for (int seq_len : {1, 11, 17, 50}) {
+    for (int tag_num : test_sizes) {
+      auto ref = jit::GetReferFunc<KernelTuple>();
+      EXPECT_TRUE(ref != nullptr);
+      int x_sz = seq_len * tag_num;
+      int w_sz = (tag_num + state_trans_base_idx) * tag_num;
+      std::vector<T> x(x_sz), w(w_sz), alpharef(x_sz);
+      std::vector<int> trackref(x_sz);
+      RandomVec<T>(x_sz, x.data());
+      RandomVec<T>(w_sz, w.data());
+
+      ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(),
+          trackref.data(), tag_num);
+
+      auto verifier = [](
+          const typename KernelTuple::func_type tgt, const int& seq_len,
+          const std::vector<T>& x, const std::vector<T>& w,
+          const std::vector<T>& alpharef, const std::vector<int>& trackref,
+          const typename KernelTuple::attr_type& tag_num) {
+        constexpr int state_trans_base_idx = 2;
+        EXPECT_TRUE(tgt != nullptr);
+        EXPECT_EQ(x.size(), static_cast<size_t>(seq_len * tag_num));
+        EXPECT_EQ(w.size(), static_cast<size_t>(
+                                (tag_num + state_trans_base_idx) * tag_num));
+        EXPECT_EQ(alpharef.size(), static_cast<size_t>(seq_len * tag_num));
+        EXPECT_EQ(trackref.size(), static_cast<size_t>(seq_len * tag_num));
+        std::vector<T> alphatgt(alpharef.size());
+        std::vector<int> tracktgt(trackref.size());
+        memcpy(tracktgt.data(), trackref.data(), tag_num * sizeof(int));
+        tgt(seq_len, (const T*)x.data(), (const T*)w.data(), alphatgt.data(),
+            tracktgt.data(), tag_num);
+        ExpectEQ<T>(alpharef.data(), alphatgt.data(), seq_len * tag_num);
+        ExpectEQ<int>(trackref.data(), tracktgt.data(), seq_len * tag_num);
+      };
+      TestAllImpls<KernelTuple, PlaceType>(tag_num, verifier, seq_len, x, w,
+                                           alpharef, trackref, tag_num);
+    }
+  }
+}
+
+template <typename KernelTuple, typename PlaceType>
+void TestKernelSeqPool() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   std::vector<jit::SeqPoolType> pool_types = {
       jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
   auto test_sizes = TestSizes();
@@ -668,7 +589,7 @@ void TestKernelSeqPoolTuples() {
       jit::seq_pool_attr_t attr(w, type);
       for (int h : test_sizes) {
         attr.h = h;
-        auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> x(h * w), yref(w);
         RandomVec<T>(h * w, x.data());
@@ -676,16 +597,86 @@ void TestKernelSeqPoolTuples() {
         T* yref_data = yref.data();
         ref(x_data, yref_data, &attr);
         VLOG(10) << attr;
-        TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
-                     std::vector<T>>(attr, x, yref, attr);
+        auto verifier = [](const typename KernelTuple::func_type tgt,
+                           const std::vector<T>& x, const std::vector<T>& yref,
+                           const typename KernelTuple::attr_type& attr) {
+          EXPECT_TRUE(tgt != nullptr);
+          EXPECT_EQ(x.size() % yref.size(), static_cast<size_t>(0));
+          int w = yref.size();
+          std::vector<T> y(w);
+          const T* x_data = x.data();
+          const T* yref_data = yref.data();
+          T* y_data = y.data();
+          tgt(x_data, y_data, &attr);
+          ExpectEQ<T>(y_data, yref_data, w);
+        };
+        TestAllImpls<KernelTuple, PlaceType>(attr, verifier, x, yref, attr);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelMatMulTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelEmbSeqPool() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  int64_t tbl_h = 1e4;
+  std::vector<jit::SeqPoolType> pool_types = {
+      jit::SeqPoolType::kSum};  // only sum is supported so far
+  auto test_sizes = TestSizes();
+  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
+  for (int tbl_w : test_sizes) {
+    std::vector<T> table(tbl_h * tbl_w);
+    RandomVec<T>(tbl_h * tbl_w, table.data());
+    const T* table_data = table.data();
+    for (auto type : pool_types) {
+      for (int idx_w : {1, 2, 10, 16}) {
+        for (int idx_h : {1, 2, 9, 13, 16}) {
+          auto ref = jit::GetReferFunc<KernelTuple>();
+          EXPECT_TRUE(ref != nullptr);
+          std::vector<int64_t> idx(idx_h * idx_w);
+          RandomVec<int64_t>(idx_h * idx_w, idx.data(), 0, tbl_h - 1);
+          int64_t out_w = tbl_w * idx_w;
+          std::vector<T> oref(out_w);
+          const int64_t* idx_data = idx.data();
+          T* o_data = oref.data();
+          jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
+                                        type);
+          ref(table_data, idx_data, o_data, &attr);
+
+          auto verifier = [](const typename KernelTuple::func_type tgt,
+                             const std::vector<T>& table,
+                             const std::vector<int64_t>& idx,
+                             const std::vector<T>& oref,
+                             const typename KernelTuple::attr_type& attr) {
+            EXPECT_TRUE(tgt != nullptr);
+            EXPECT_EQ(table.size(), static_cast<size_t>(attr.table_height *
+                                                        attr.table_width));
+            EXPECT_EQ(idx.size(), static_cast<size_t>(attr.index_height *
+                                                      attr.index_width));
+            EXPECT_EQ(oref.size(),
+                      static_cast<size_t>(attr.table_width * attr.index_width));
+            const T* table_data = table.data();
+            const int64_t* idx_data = idx.data();
+            const T* oref_data = oref.data();
+            int o_w = oref.size();
+            std::vector<T> out(o_w);
+            T* o_data = out.data();
+            tgt(table_data, idx_data, o_data, &attr);
+            ExpectEQ<T>(o_data, oref_data, o_w);
+          };
+          TestAllImpls<KernelTuple, PlaceType>(attr, verifier, table, idx, oref,
+                                               attr);
+        }
+      }
+    }
+  }
+}
+
+template <typename KernelTuple, typename PlaceType>
+void TestKernelMatMul() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   auto last_acc = FLAGS_acc;
   // export MKL_CBWR=AVX would make MKL force to use AVX
   // export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic
@@ -693,7 +684,7 @@ void TestKernelMatMulTuples() {
   for (int m : {1, 2, 3, 4}) {
     for (int n : {1, 2, 3, 4}) {
       for (int k : TestSizes()) {
-        auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> a(m * k), b(k * n), c(m * n);
         RandomVec<T>(m * k, a.data());
@@ -703,20 +694,36 @@ void TestKernelMatMulTuples() {
         T* c_data = c.data();
         const jit::matmul_attr_t attr{m, n, k};
         ref(a_data, b_data, c_data, &attr);
-        TestAllImpls<KT, jit::MatMulTuples<T>, PlaceType, std::vector<T>,
-                     std::vector<T>, std::vector<T>>(attr, a, b, c, attr);
+        auto verifier = [](const typename KernelTuple::func_type tgt,
+                           const std::vector<T>& a, const std::vector<T>& b,
+                           const std::vector<T>& cref,
+                           const typename KernelTuple::attr_type& attr) {
+          EXPECT_TRUE(tgt != nullptr);
+          EXPECT_EQ(a.size(), static_cast<size_t>(attr.m * attr.k));
+          EXPECT_EQ(b.size(), static_cast<size_t>(attr.k * attr.n));
+          EXPECT_EQ(cref.size(), static_cast<size_t>(attr.m * attr.n));
+          std::vector<T> c(cref.size());
+          const T* a_data = a.data();
+          const T* b_data = b.data();
+          const T* cref_data = cref.data();
+          T* c_data = c.data();
+          tgt(a_data, b_data, c_data, &attr);
+          ExpectEQ<T>(c_data, cref_data, attr.m * attr.n);
+        };
+        TestAllImpls<KernelTuple, PlaceType>(attr, verifier, a, b, c, attr);
       }
     }
   }
   FLAGS_acc = last_acc;
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelSoftmaxTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelSoftmax() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   for (int bs : {1, 2, 10}) {
     for (int n : TestSizes()) {
-      auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>();
+      auto ref = jit::GetReferFunc<KernelTuple>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(bs * n), y(bs * n);
       RandomVec<T>(bs * n, x.data());
@@ -730,51 +737,33 @@ void TestKernelSoftmaxTuples() {
       ref(xinp_data, xinp_data, n, bs);
       ExpectEQ<T>(xinp_data, y_data, n * bs);
 
-      TestAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType, std::vector<T>,
-                   std::vector<T>>(n, x, y, n, bs);
-    }
-  }
-}
-
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelEmbSeqPoolTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  int64_t tbl_h = 1e4;
-  std::vector<jit::SeqPoolType> pool_types = {
-      jit::SeqPoolType::kSum};  // only support sum yet
-  auto test_sizes = TestSizes();
-  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
-  for (int tbl_w : test_sizes) {
-    std::vector<T> table(tbl_h * tbl_w);
-    RandomVec<T>(tbl_h * tbl_w, table.data());
-    const T* table_data = table.data();
-    for (auto type : pool_types) {
-      for (int idx_w : {1, 2, 10, 16}) {
-        for (int idx_h : {1, 2, 9, 13, 16}) {
-          auto ref = jit::GetRefer<KT, jit::EmbSeqPoolTuples<T>>();
-          EXPECT_TRUE(ref != nullptr);
-          std::vector<int64_t> idx(idx_h * idx_w);
-          RandomVec<int64_t>(idx_h * idx_w, idx.data(), 0, tbl_h - 1);
-          int64_t out_w = tbl_w * idx_w;
-          std::vector<T> oref(out_w);
-          const int64_t* idx_data = idx.data();
-          T* o_data = oref.data();
-          jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
-                                        type);
-          ref(table_data, idx_data, o_data, &attr);
-
-          TestAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType, std::vector<T>,
-                       std::vector<int64_t>, std::vector<T>>(attr, table, idx,
-                                                             oref, attr);
-        }
-      }
+      auto verifier = [](const typename KernelTuple::func_type tgt,
+                         const std::vector<T>& x, const std::vector<T>& yref,
+                         int n, int bs) {
+        EXPECT_TRUE(tgt != nullptr);
+        EXPECT_EQ(yref.size(), x.size());
+        EXPECT_EQ(x.size(), static_cast<size_t>(n * bs));
+        const T* x_data = x.data();
+        const T* yref_data = yref.data();
+        std::vector<T> ytgt(n * bs);
+        T* ytgt_data = ytgt.data();
+        // test normal
+        tgt(x_data, ytgt_data, n, bs);
+        ExpectEQ<T>(ytgt_data, yref_data, n * bs);
+        // test inplace x
+        std::copy(x.begin(), x.end(), ytgt.begin());
+        tgt(ytgt_data, ytgt_data, n, bs);
+        ExpectEQ<T>(ytgt_data, yref_data, n * bs);
+      };
+      TestAllImpls<KernelTuple, PlaceType>(n, verifier, x, y, n, bs);
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelSgdTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelSgd() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
   const T lr = 0.1;
   auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
                                   const int64_t upper) -> std::vector<int64_t> {
@@ -802,7 +791,7 @@ void TestKernelSgdTuples() {
         RandomVec<T>(rows_size * grad_w, grad.data());
         const int64_t* rows_data = rows.data();
         const T* grad_data = grad.data();
-        auto ref = jit::GetRefer<KT, jit::SgdTuples<T>>();
+        auto ref = jit::GetReferFunc<KernelTuple>();
         EXPECT_TRUE(ref != nullptr);
         jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
         ref(&lr, param_data, grad_data, rows_data, out_data, &attr);
@@ -818,227 +807,488 @@ void TestKernelSgdTuples() {
                       grad_w);
         }
 
-        TestAllImpls<KT, jit::SgdTuples<T>, PlaceType, T, std::vector<T>,
-                     std::vector<T>, std::vector<int64_t>, std::vector<T>>(
-            attr, lr, param, grad, rows, param_out, attr);
+        auto verifier = [](
+            const typename KernelTuple::func_type tgt, const T lr,
+            const std::vector<T>& param, const std::vector<T>& grad,
+            const std::vector<int64_t>& rows, const std::vector<T>& oref,
+            const typename KernelTuple::attr_type& attr) {
+          EXPECT_TRUE(tgt != nullptr);
+          EXPECT_EQ(param.size(),
+                    static_cast<size_t>(attr.param_height * attr.param_width));
+          EXPECT_EQ(grad.size(),
+                    static_cast<size_t>(attr.grad_height * attr.grad_width));
+          EXPECT_EQ(rows.size(), static_cast<size_t>(attr.selected_rows_size));
+          EXPECT_EQ(param.size(), oref.size());
+          const T* param_data = param.data();
+          const T* grad_data = grad.data();
+          const int64_t* rows_data = rows.data();
+          const T* oref_data = oref.data();
+
+          std::vector<T> out(oref.size());
+          T* o_data = out.data();
+          tgt(&lr, param_data, grad_data, rows_data, o_data, &attr);
+          // only the selected rows should be equal
+          for (size_t i = 0; i < rows.size(); ++i) {
+            ExpectEQ<T>(o_data + rows[i] * attr.grad_width,
+                        oref_data + rows[i] * attr.grad_width, attr.grad_width);
+          }
+
+          // test inplace update: param and out share the same buffer
+          std::copy(param.begin(), param.end(), out.begin());
+          tgt(&lr, o_data, grad_data, rows_data, o_data, &attr);
+          for (size_t i = 0; i < rows.size(); ++i) {
+            ExpectEQ<T>(o_data + rows[i] * attr.grad_width,
+                        oref_data + rows[i] * attr.grad_width, attr.grad_width);
+          }
+        };
+        TestAllImpls<KernelTuple, PlaceType>(attr, verifier, lr, param, grad,
+                                             rows, param_out, attr);
       }
     }
   }
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelNCHW16CMulNCTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  const int n = 3, c = 16 * 4, h = 10, w = 10;
-  auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
-  EXPECT_TRUE(ref != nullptr);
-  int sz = n * c * h * w;
-  std::vector<T> x(sz), y(n * c), zref(sz);
-  std::vector<T> ztgt(sz), zjit(sz);
-  RandomVec<T>(sz, x.data());
-  RandomVec<T>(n * c, y.data());
-
-  const T* x_data = x.data();
-  const T* y_data = y.data();
-  T* zref_data = zref.data();
-  T* ztgt_data = ztgt.data();
-  T* zjit_data = zjit.data();
-  constexpr int simd_width = ZMM_FLOAT_BLOCK;
-  int C = c / simd_width;
-  auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
-  auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
-  EXPECT_TRUE(tgt != nullptr);
-
-  if (std::is_same<T, float>::value &&
-      paddle::platform::MayIUse(paddle::platform::avx512f)) {
-    EXPECT_TRUE(jitcode != nullptr);
-  }
-  for (int ni = 0; ni < n; ni++) {
-    for (int ci = 0; ci < C; ci++) {
-      auto ptr_x =
-          x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-      auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-      auto ptr_zref =
-          zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-      auto ptr_ztgt =
-          ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-
-      ref(ptr_x, ptr_y, ptr_zref, h, w);
-      tgt(ptr_x, ptr_y, ptr_ztgt, h, w);
+template <typename KernelTuple, typename PlaceType>
+void TestKernelVBroadcast() {
+  using T = typename KernelTuple::data_type;
+  VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
+  for (int w : TestSizes()) {
+    std::vector<T> x(w);
+    RandomVec<T>(w, x.data());
+    const T* x_data = x.data();
+    for (int64_t h : {1, 2, 6}) {
+      auto ref = jit::GetReferFunc<KernelTuple>();
+      EXPECT_TRUE(ref != nullptr);
+      std::vector<T> y(w * h);
+      T* y_data = y.data();
+      ref(x_data, y_data, h, w);
 
-      if (jitcode) {
-        auto ptr_zjit =
-            zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-        jitcode(ptr_x, ptr_y, ptr_zjit, h, w);
-      }
+      auto verifier = [](const typename KernelTuple::func_type tgt,
+                         const std::vector<T>& x, const std::vector<T>& yref,
+                         const int64_t& h,
+                         const typename KernelTuple::attr_type& attr) {
+        EXPECT_TRUE(tgt != nullptr);
+        EXPECT_EQ(x.size(), static_cast<size_t>(attr));
+        EXPECT_EQ(yref.size(), x.size() * h);
+        std::vector<T> y(yref.size());
+        const T* x_data = x.data();
+        const T* yref_data = yref.data();
+        T* y_data = y.data();
+        tgt(x_data, y_data, h, attr);
+        ExpectEQ<T>(y_data, yref_data, yref.size());
+      };
+      TestAllImpls<KernelTuple, PlaceType>(static_cast<int64_t>(w), verifier, x,
+                                           y, h, static_cast<int64_t>(w));
     }
   }
-  ExpectEQ<T>(ztgt_data, zref_data, sz);
-  if (jitcode) {
-    ExpectEQ<T>(zjit_data, zref_data, sz);
-  }
 }
 
-template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelLayerNormTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  const T epsilon = 9.99999975e-06;
-  for (int n : {1, 2, 10}) {
-    for (int x_dim_0 : {1, 9, 17, 50}) {
-      int left = n * x_dim_0;
-      for (int x_dim_1 : TestSizes()) {
-        int right = x_dim_1;
-        auto ref = jit::GetRefer<KT, jit::LayerNormTuples<T>>();
-        EXPECT_TRUE(ref != nullptr);
-        int sz = left * right;
-        std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
-            outref(sz);
-        RandomVec<T>(sz, x.data());
-        RandomVec<T>(left, mean.data());
-        RandomVec<T>(left, var.data());
-        RandomVec<T>(right, scale.data());
-        RandomVec<T>(right, bias.data());
-
-        const T* scale_data = scale.data();
-        const T* bias_data = bias.data();
-        T* x_data = x.data();
-        T* mean_data = mean.data();
-        T* var_data = var.data();
-        T* outref_data = outref.data();
+// test pool
+TEST(JITKernel_pool, jitcreator) {
+  const auto& jitcreators = jit::JitCodeCreatorPool::Instance().AllCreators();
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_EQ(jitcreators.size(), 0UL);
+#else
+  EXPECT_EQ(jitcreators.size(), 25UL);
+#endif
+}
 
-        ref(x_data, outref_data, mean_data, var_data, scale_data, bias_data,
-            left, epsilon, right);
+TEST(JITKernel_pool, jitpool) {
+  // the jitcode pool is keyed by the kernel attr
+  const auto& kers = jit::JitCodePool<jit::kVAdd>().Instance().AllKernels();
+  EXPECT_EQ(kers.size(), 0UL);
+  jit::GetAllCandidateKernels<jit::VAddTuple<float>, CPUPlace>(3);
+// after calling GetAllCandidateKernels, jitcode is created automatically
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_EQ(kers.size(), 0UL);
+#else
+  EXPECT_EQ(kers.size(), 1UL);
+#endif
+}
 
-        TestAllImpls<KT, jit::LayerNormTuples<T>, PlaceType, std::vector<T>,
-                     std::vector<T>, std::vector<T>, std::vector<T>,
-                     std::vector<T>, std::vector<T>, int, float>(
-            right, x, outref, mean, var, scale, bias, left, epsilon, right);
-      }
-    }
-  }
+TEST(JITKernel_pool, more) {
+  const auto& kers = jit::KernelPool::Instance().AllKernels();
+#if defined(__APPLE__) || defined(__OSX__)
+  EXPECT_EQ(kers.size(), 10UL);
+#else
+#ifdef PADDLE_WITH_MKLML
+  EXPECT_EQ(kers.size(), 21UL);
+#else
+  EXPECT_EQ(kers.size(), 8UL);
+#endif
+#endif
 }
 
-template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelCRFDecodingTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  constexpr int state_trans_base_idx = 2;
-  auto test_sizes = TestSizes();
-  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
-  for (int seq_len : {1, 11, 17, 50}) {
-    for (int tag_num : test_sizes) {
-      auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
-      EXPECT_TRUE(ref != nullptr);
-      int x_sz = seq_len * tag_num;
-      int w_sz = (tag_num + state_trans_base_idx) * tag_num;
-      std::vector<T> x(x_sz), w(w_sz), alpharef(x_sz);
-      std::vector<int> trackref(x_sz);
-      RandomVec<T>(x_sz, x.data());
-      RandomVec<T>(w_sz, w.data());
+TEST(JITKernel_pool, refer) {
+  const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
+  EXPECT_EQ(kers.size(), 29UL);
+}
 
-      ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(),
-          trackref.data(), tag_num);
+// test helper
+TEST(JITKernel_helper, GetAllCandidateKernels) {
+  auto fp_kers =
+      jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_GE(fp_kers.size(), 1UL);  // refer
+#else
+#ifdef PADDLE_WITH_MKLML
+  EXPECT_GE(fp_kers.size(), 3UL);  // jitcode, mkl, refer
+#else
+  EXPECT_GE(fp_kers.size(), 2UL);  // jitcode, refer
+#endif
+#endif
+
+  auto db_kers =
+      jit::GetAllCandidateKernels<jit::VExpTuple<double>, CPUPlace>(10);
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_GE(db_kers.size(), 1UL);  // refer
+#else
+#ifdef PADDLE_WITH_MKLML
+  EXPECT_GE(db_kers.size(), 2UL);  // mkl, refer
+#else
+  EXPECT_GE(db_kers.size(), 1UL);  // refer
+#endif
+#endif
+}
 
-      TestAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType, int,
-                   std::vector<T>, std::vector<T>, std::vector<T>,
-                   std::vector<int>, int>(tag_num, seq_len, x, w, alpharef,
-                                          trackref, tag_num);
-    }
-  }
+TEST(JITKernel_helper, GetAllCandidateFuncsWithTypes) {
+  auto fp_kers =
+      jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<float>, CPUPlace>(10);
+#if defined(__APPLE__) || defined(__OSX__)
+  EXPECT_GE(fp_kers.size(), 1UL);  // refer
+#else
+#if !defined(PADDLE_WITH_MKLML) || defined(_WIN32)
+  EXPECT_GE(fp_kers.size(), 2UL);  // jitcode/mkl, refer
+#else
+  EXPECT_GE(fp_kers.size(), 3UL);  // jitcode, mkl, refer
+#endif
+#endif
+
+  auto db_kers =
+      jit::GetAllCandidateFuncsWithTypes<jit::VExpTuple<double>, CPUPlace>(10);
+#if defined(__APPLE__) || defined(__OSX__) || !defined(PADDLE_WITH_MKLML)
+  EXPECT_GE(db_kers.size(), 1UL);  // refer
+#else
+  EXPECT_GE(db_kers.size(), 2UL);  // mkl, refer
+#endif
 }
 
-template <jit::KernelType KT, typename T, typename PlaceType>
-void TestKernelVBroadcastTuples() {
-  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
-  for (int w : TestSizes()) {
-    std::vector<T> x(w);
-    RandomVec<T>(w, x.data());
-    const T* x_data = x.data();
-    for (int64_t h : {1, 2, 6}) {
-      auto ref = jit::GetRefer<KT, jit::VBroadcastTuples<T>>();
-      EXPECT_TRUE(ref != nullptr);
-      std::vector<T> y(w * h);
-      T* y_data = y.data();
-      ref(x_data, y_data, h, w);
+TEST(JITKernel_helper, KernelFuncs) {
+  auto f1 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache().At(3);
+  auto f2 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[3];
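+  // At() and operator[] should return the same cached function for one attr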
+  EXPECT_TRUE(f1 != nullptr);
+  EXPECT_TRUE(f1 == f2);
+
+  auto f3 = jit::KernelFuncs<jit::VAddTuple<float>, CPUPlace>::Cache()[5];
+#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
+  EXPECT_TRUE(f2 == f3);
+#else
+  EXPECT_TRUE(f2 != f3);
+#endif
+}
 
-      TestAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType, std::vector<T>,
-                   std::vector<T>, int64_t>(static_cast<int64_t>(w), x, y, h,
-                                            static_cast<int64_t>(w));
-    }
+TEST(JITKernel_helper, GetAllCandidateFuncs) {
+  auto funcs = jit::GetAllCandidateFuncs<jit::VExpTuple<float>, CPUPlace>(10);
+  auto kers = jit::GetAllCandidateKernels<jit::VExpTuple<float>, CPUPlace>(10);
+  EXPECT_EQ(funcs.size(), kers.size());
+
+  std::vector<float> x(10), tgt(10);
+  RandomVec<float>(10, x.data());
+  auto best = jit::GetDefaultBestFunc<jit::VExpTuple<float>, CPUPlace>(10);
+  best(x.data(), tgt.data(), 10);
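+  // every candidate should produce the same result as the default best func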
+  for (auto f : funcs) {
+    std::vector<float> y(10);
+    f(x.data(), y.data(), 10);
+    ExpectEQ<float>(y.data(), tgt.data(), 10);
   }
 }
 
-#define TEST_CPU_KERNEL(test_tuple, kernel_type)                 \
-  TEST(JITKernel, kernel_type) {                                 \
-    TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
-    TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
+TEST(JITKernel_helper, pack_weights) {
+  const int N = 8 * 60, K = 2;
+  float src[K][N], yref[K][N], y[K * N];
+  float* x = &(src[0][0]);
+  float* ref = &(yref[0][0]);
+  for (int i = 0; i < N * K; ++i) {
+    *(x + i) = static_cast<float>(i);
+  }
+  int block = 0;
+  std::vector<int> groups;
+  if (paddle::platform::MayIUse(paddle::platform::avx512f)) {
+    block = ZMM_FLOAT_BLOCK;
+    groups.push_back(30);
+  } else {
+    block = YMM_FLOAT_BLOCK;
+    groups.insert(groups.end(), {14, 14, 14, 14, 4});
   }
 
-TEST_CPU_KERNEL(XYZNTuples, kVMul);
-TEST_CPU_KERNEL(XYZNTuples, kVAdd);
-TEST_CPU_KERNEL(XYZNTuples, kVAddRelu);
-TEST_CPU_KERNEL(XYZNTuples, kVSub);
-
-TEST_CPU_KERNEL(AXYNTuples, kVScal);
-TEST_CPU_KERNEL(AXYNTuples, kVAddBias);
+  int offset = 0;
+  int acc = 0;
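+  // expected packed layout: for each column group (g blocks of `block` floats),
+  // copy that group's columns from every row k before moving to the next group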
+  for (int g : groups) {
+    g = g * block;
+    for (int k = 0; k < K; ++k) {
+      for (int i = 0; i < g; ++i) {
+        *(ref + offset) = src[k][i + acc];
+        offset++;
+      }
+    }
+    acc += g;
+  }
 
-TEST_CPU_KERNEL(XRNTuples, kHMax);
-TEST_CPU_KERNEL(XRNTuples, kHSum);
+  jit::pack_weights<float>(x, y, N, K);
+  ExpectEQ<float>(y, ref, N * K);
+}
 
-TEST_CPU_KERNEL(XYNTuples, kVRelu);
-TEST_CPU_KERNEL(XYNTuples, kVIdentity);
-TEST_CPU_KERNEL(XYNTuples, kVSquare);
-TEST_CPU_KERNEL(XYNTuples, kVExp);
-TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
-TEST_CPU_KERNEL(XYNTuples, kVTanh);
-TEST_CPU_KERNEL(XYNTuples, kVCopy);
+TEST(JITKernel_helper, attr) {
+  std::ostringstream out;
+  // KernelTypes
+  out << jit::to_string(jit::kNone) << jit::to_string(jit::kCRFDecoding)
+      << jit::to_string(jit::kEmbSeqPool) << jit::to_string(jit::kGRUH1)
+      << jit::to_string(jit::kGRUHtPart1) << jit::to_string(jit::kGRUHtPart2)
+      << jit::to_string(jit::kHSum) << jit::to_string(jit::kHMax)
+      << jit::to_string(jit::kLSTMCtHt) << jit::to_string(jit::kLSTMC1H1)
+      << jit::to_string(jit::kLayerNorm) << jit::to_string(jit::kMatMul)
+      << jit::to_string(jit::kNCHW16CMulNC) << jit::to_string(jit::kSeqPool)
+      << jit::to_string(jit::kSoftmax) << jit::to_string(jit::kVAdd)
+      << jit::to_string(jit::kVAddBias) << jit::to_string(jit::kVAddRelu)
+      << jit::to_string(jit::kVBroadcast) << jit::to_string(jit::kVCopy)
+      << jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity)
+      << jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu)
+      << jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd)
+      << jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare)
+      << jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh);
+  EXPECT_EQ(out.str().size(), 234);
+
+  // SeqPoolTypes
+  out.str("");
+  out << jit::to_string(jit::kSum) << jit::to_string(jit::kAvg)
+      << jit::to_string(jit::kSqrt);
+  EXPECT_EQ(out.str().size(), 13);
+
+  EXPECT_EQ(jit::to_kerneltype("relu"), jit::kVRelu);
+  EXPECT_EQ(jit::to_kerneltype("Identity"), jit::kVIdentity);
+  EXPECT_EQ(jit::to_kerneltype("VEXP"), jit::kVExp);
+  EXPECT_EQ(jit::to_kerneltype("SigmoiD"), jit::kVSigmoid);
+  EXPECT_EQ(jit::to_kerneltype("VTanh"), jit::kVTanh);
+
+  out.str("");
+  out << jit::lstm_attr_t(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
+  EXPECT_EQ(out.str().size(), 89);
+
+  out.str("");
+  out << jit::gru_attr_t(8, jit::kVIdentity, jit::kVSigmoid);
+  EXPECT_EQ(out.str().size(), 52);
+
+  out.str("");
+  out << jit::seq_pool_attr_t(8, jit::SeqPoolType::kSum);
+  EXPECT_EQ(out.str().size(), 44);
+
+  out.str("");
+  out << jit::emb_seq_pool_attr_t(1, 2, 3, 4, 5, jit::SeqPoolType::kAvg);
+  EXPECT_EQ(out.str().size(), 93);
+
+  out.str("");
+  out << jit::sgd_attr_t(1, 2, 3, 4, 5);
+  EXPECT_EQ(out.str().size(), 81);
+
+  out.str("");
+  out << jit::matmul_attr_t(1, 2, 3);
+  EXPECT_EQ(out.str().size(), 14);
+}
 
-TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
-TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);
+// test keys
+TEST(JITKernel_key, int) {
+  EXPECT_TRUE(jit::JitCodeKey<int>(2) == jit::JitCodeKey<int>(2));
+  EXPECT_TRUE(jit::JitCodeKey<int>(2) == jit::JitCodeKey<int64_t>(2));
+  EXPECT_TRUE(jit::JitCodeKey<int>(2) != jit::JitCodeKey<int>(3));
+}
 
-TEST_CPU_KERNEL(GRUTuples, kGRUH1);
-TEST_CPU_KERNEL(GRUTuples, kGRUHtPart1);
-TEST_CPU_KERNEL(GRUTuples, kGRUHtPart2);
+TEST(JITKernel_key, gru) {
+  jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
+  jit::gru_attr_t attr2(8, jit::kVSigmoid, jit::kVTanh);
+  jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
+  jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);
+  jit::gru_attr_t attr5(9, jit::kVTanh, jit::kVIdentity);
 
-TEST_CPU_KERNEL(NCHW16CMulNCTuples, kNCHW16CMulNC);
+  auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
+  auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
+  auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
+  auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);
+  auto key5 = jit::JitCodeKey<jit::gru_attr_t>(attr5);
 
-TEST_CPU_KERNEL(SeqPoolTuples, kSeqPool);
-TEST_CPU_KERNEL(MatMulTuples, kMatMul);
-TEST_CPU_KERNEL(SoftmaxTuples, kSoftmax);
-TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
-TEST_CPU_KERNEL(SgdTuples, kSgd);
-TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
-TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
-TEST_CPU_KERNEL(VBroadcastTuples, kVBroadcast);
+  EXPECT_TRUE(key1 == key2);
+  EXPECT_TRUE(key2 != key3);
+  EXPECT_TRUE(key2 != key4);
+  EXPECT_TRUE(key2 != key5);
+  EXPECT_TRUE(key3 != key4);
+  EXPECT_TRUE(key3 != key5);
+  EXPECT_TRUE(key4 != key5);
+}
 
 TEST(JITKernel_key, lstm) {
   jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
-  jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
+  jit::lstm_attr_t attr2(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
   jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
   jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);
+  jit::lstm_attr_t attr5(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true);
+  jit::lstm_attr_t attr6(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh, true);
 
   auto key1 = jit::JitCodeKey<jit::lstm_attr_t>(attr1);
   auto key2 = jit::JitCodeKey<jit::lstm_attr_t>(attr2);
   auto key3 = jit::JitCodeKey<jit::lstm_attr_t>(attr3);
   auto key4 = jit::JitCodeKey<jit::lstm_attr_t>(attr4);
+  auto key5 = jit::JitCodeKey<jit::lstm_attr_t>(attr5);
+  auto key6 = jit::JitCodeKey<jit::lstm_attr_t>(attr6);
 
-  EXPECT_TRUE(key1 != key2);
-  EXPECT_TRUE(key2 == key3);
+  EXPECT_TRUE(key1 == key2);
+  EXPECT_TRUE(key2 != key3);
+  EXPECT_TRUE(key2 != key4);
+  EXPECT_TRUE(key2 != key5);
   EXPECT_TRUE(key3 != key4);
+  EXPECT_TRUE(key3 != key5);
+  EXPECT_TRUE(key4 != key5);
+  EXPECT_TRUE(key5 == key6);
 }
 
-TEST(JITKernel_key, gru) {
-  jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
-  jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh);
-  jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
-  jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);
+TEST(JITKernel_key, seq_pool) {
+  jit::seq_pool_attr_t attr1(2, jit::SeqPoolType::kSum, 1);
+  jit::seq_pool_attr_t attr2(2, jit::SeqPoolType::kSum, 3);
+  jit::seq_pool_attr_t attr3(3, jit::SeqPoolType::kSum, 3);
+  jit::seq_pool_attr_t attr4(3, jit::SeqPoolType::kAvg, 3);
 
-  auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
-  auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
-  auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
-  auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);
+  auto key1 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr1);
+  auto key2 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr2);
+  auto key3 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr3);
+  auto key4 = jit::JitCodeKey<jit::seq_pool_attr_t>(attr4);
 
-  EXPECT_TRUE(key1 != key2);
+  EXPECT_TRUE(key1 == key2);
+  EXPECT_TRUE(key2 != key3);
+  EXPECT_TRUE(key2 != key4);
+  EXPECT_TRUE(key3 != key4);
+}
+
+TEST(JITKernel_key, matmul) {
+  jit::matmul_attr_t attr1(1, 2, 3);
+  jit::matmul_attr_t attr2(1, 2, 3);
+  jit::matmul_attr_t attr3(1, 3, 3);
+  jit::matmul_attr_t attr4(2, 3, 4);
+
+  auto key1 = jit::JitCodeKey<jit::matmul_attr_t>(attr1);
+  auto key2 = jit::JitCodeKey<jit::matmul_attr_t>(attr2);
+  auto key3 = jit::JitCodeKey<jit::matmul_attr_t>(attr3);
+  auto key4 = jit::JitCodeKey<jit::matmul_attr_t>(attr4);
+
+  EXPECT_TRUE(key1 == key2);
+  EXPECT_TRUE(key2 != key3);
+  EXPECT_TRUE(key2 != key4);
+  EXPECT_TRUE(key3 != key4);
+}
+
+TEST(JITKernel_key, emb_seq_pool) {
+  jit::emb_seq_pool_attr_t attr1(1, 2, 3, 4, 5, jit::SeqPoolType::kSum);
+  jit::emb_seq_pool_attr_t attr2(1, 2, 3, 4, 5, jit::SeqPoolType::kSum);
+  jit::emb_seq_pool_attr_t attr3(10, 2, 9, 8, 7, jit::SeqPoolType::kAvg);
+  jit::emb_seq_pool_attr_t attr4(10, 3, 9, 8, 7, jit::SeqPoolType::kSum);
+  jit::emb_seq_pool_attr_t attr5(1, 6, 3, 4, 5, jit::SeqPoolType::kSum);
+
+  auto key1 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr1);
+  auto key2 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr2);
+  auto key3 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr3);
+  auto key4 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr4);
+  auto key5 = jit::JitCodeKey<jit::emb_seq_pool_attr_t>(attr5);
+
+  EXPECT_TRUE(key1 == key2);
+  EXPECT_TRUE(key2 == key3);
+  EXPECT_TRUE(key2 != key4);
+  EXPECT_TRUE(key2 != key5);
+  EXPECT_TRUE(key4 != key5);
+}
+
+TEST(JITKernel_key, sgd) {
+  jit::sgd_attr_t attr1(1, 2, 3, 4, 5);
+  jit::sgd_attr_t attr2(1, 2, 3, 4, 5);
+  jit::sgd_attr_t attr3(9, 8, 7, 4, 6);
+  jit::sgd_attr_t attr4(1, 2, 3, 6, 5);
+  jit::sgd_attr_t attr5(10, 9, 8, 7, 6);
+
+  auto key1 = jit::JitCodeKey<jit::sgd_attr_t>(attr1);
+  auto key2 = jit::JitCodeKey<jit::sgd_attr_t>(attr2);
+  auto key3 = jit::JitCodeKey<jit::sgd_attr_t>(attr3);
+  auto key4 = jit::JitCodeKey<jit::sgd_attr_t>(attr4);
+  auto key5 = jit::JitCodeKey<jit::sgd_attr_t>(attr5);
+
+  EXPECT_TRUE(key1 == key2);
   EXPECT_TRUE(key2 == key3);
   EXPECT_TRUE(key3 != key4);
+  EXPECT_TRUE(key3 != key5);
+  EXPECT_TRUE(key4 != key5);
 }
-// TODO(TJ): add more test about key and pool
+
+// test kernels
+#define TestKernelVMul TestKernelXYZN
+#define TestKernelVAdd TestKernelXYZN
+#define TestKernelVAddRelu TestKernelXYZN
+#define TestKernelVSub TestKernelXYZN
+
+#define TestKernelVScal TestKernelAXYN
+#define TestKernelVAddBias TestKernelAXYN
+
+#define TestKernelVRelu TestKernelXYN
+#define TestKernelVIdentity TestKernelXYN
+#define TestKernelVSquare TestKernelXYN
+#define TestKernelVExp TestKernelXYN
+#define TestKernelVSigmoid TestKernelXYN
+#define TestKernelVTanh TestKernelXYN
+#define TestKernelVCopy TestKernelXYN
+
+#define TestKernelHMax TestKernelXRN
+#define TestKernelHSum TestKernelXRN
+
+#define TestKernelLSTMCtHt TestKernelLSTM
+#define TestKernelLSTMC1H1 TestKernelLSTM
+
+#define TestKernelGRUH1 TestKernelGRU
+#define TestKernelGRUHtPart1 TestKernelGRU
+#define TestKernelGRUHtPart2 TestKernelGRU
+
+#define TEST_CPU_KERNEL(kernel_type)                                      \
+  TEST(JITKernel, kernel_type) {                                          \
+    TestKernel##kernel_type<jit::kernel_type##Tuple<float>, CPUPlace>();  \
+    TestKernel##kernel_type<jit::kernel_type##Tuple<double>, CPUPlace>(); \
+  }
+
+TEST_CPU_KERNEL(VMul);
+TEST_CPU_KERNEL(VAdd);
+TEST_CPU_KERNEL(VAddRelu);
+TEST_CPU_KERNEL(VSub);
+
+TEST_CPU_KERNEL(VScal);
+TEST_CPU_KERNEL(VAddBias);
+
+TEST_CPU_KERNEL(VRelu);
+TEST_CPU_KERNEL(VIdentity);
+TEST_CPU_KERNEL(VSquare);
+TEST_CPU_KERNEL(VExp);
+TEST_CPU_KERNEL(VSigmoid);
+TEST_CPU_KERNEL(VTanh);
+TEST_CPU_KERNEL(VCopy);
+
+TEST_CPU_KERNEL(HMax);
+TEST_CPU_KERNEL(HSum);
+
+TEST_CPU_KERNEL(LSTMCtHt);
+TEST_CPU_KERNEL(LSTMC1H1);
+
+TEST_CPU_KERNEL(GRUH1);
+TEST_CPU_KERNEL(GRUHtPart1);
+TEST_CPU_KERNEL(GRUHtPart2);
+
+TEST_CPU_KERNEL(NCHW16CMulNC);
+TEST_CPU_KERNEL(LayerNorm);
+TEST_CPU_KERNEL(CRFDecoding);
+
+TEST_CPU_KERNEL(SeqPool);
+TEST_CPU_KERNEL(EmbSeqPool);
+TEST_CPU_KERNEL(MatMul);
+TEST_CPU_KERNEL(Softmax);
+TEST_CPU_KERNEL(Sgd);
+TEST_CPU_KERNEL(VBroadcast);
diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h
index f564a10396..8627c83b43 100644
--- a/paddle/fluid/operators/layer_norm_op.h
+++ b/paddle/fluid/operators/layer_norm_op.h
@@ -230,8 +230,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(bias->numel(), right);
 
     auto ker =
-        jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
-            right);
+        jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
+            .At(right);
     ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
         scale->data<T>(), bias->data<T>(), static_cast<int>(left),
         static_cast<const float>(epsilon), right);
diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
index 0ad57c51be..66ce57594a 100644
--- a/paddle/fluid/operators/math/fc_compute.h
+++ b/paddle/fluid/operators/math/fc_compute.h
@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
     return;
   }
   if (relu) {
-    auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
-                                    platform::CPUPlace>::Cache()
-                       .At(N);
+    auto compute =
+        jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
+            N);
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
       compute(B, dst, dst, N);
     }
   } else {
-    auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
-                                    platform::CPUPlace>::Cache()
-                       .At(N);
+    auto compute =
+        jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 2a47502614..7af44f2b2c 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -256,8 +256,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
           static_cast<int>(input.numel() / input.dims()[0]),
           jit::SeqPoolType::kSum);
       auto seqpool =
-          jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
-              attr);
+          jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
+              .At(attr);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         attr.h = static_cast<int>(lod[i + 1] - lod[i]);
         seqpool(src, dst, &attr);
diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index a1cb3f9728..d77b6712c5 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
     const int kClassDim = 1;
     // 2D data. Batch x C
     auto compute_softmax =
-        jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
-                         platform::CPUPlace>::Cache()
+        jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
             .At(in_dims[kClassDim]);
     compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
   }
diff --git a/paddle/fluid/operators/optimizers/sgd_op.h b/paddle/fluid/operators/optimizers/sgd_op.h
index c9c9f530fe..5dd5f67e00 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.h
+++ b/paddle/fluid/operators/optimizers/sgd_op.h
@@ -48,7 +48,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
         T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
 
         auto sgd =
-            jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
+            jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
+                attr);
         sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
       } else if (grad_var->IsType<framework::SelectedRows>()) {
         // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
@@ -82,7 +83,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
         PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
 
         auto sgd =
-            jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
+            jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
+                attr);
         sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
       } else {
         PADDLE_THROW("Unsupported Variable Type of Grad");
diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc
index 88c968a0ea..2898a62ddb 100644
--- a/paddle/fluid/operators/recurrent_op.cc
+++ b/paddle/fluid/operators/recurrent_op.cc
@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase {
 
       // Every inputs are linked now, execute!
       executor.Run(*program, &cur_scope, block->ID(),
-                   false /*create_local_scope*/);
+                   false /*create_local_scope*/, true /*create_vars*/,
+                   std::vector<std::string>() /*skip_ref_cnt_vars*/,
+                   true /*force_disable_gc*/);
 
       // get device context from pool
       platform::DeviceContextPool &pool =
@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase {
       VLOG(5) << "Recurrent memory linking finished ";
       // Run step block with cur_scope
       executor.Run(*program, &cur_scope, block->ID(),
-                   false /*create_local_scope*/);
+                   false /*create_local_scope*/, true /*create_vars*/,
+                   std::vector<std::string>() /*skip_ref_cnt_vars*/,
+                   true /*force_disable_gc*/);
 
       VLOG(5) << "executor.Run finished ";
 
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index aeabed19ab..6bbda69297 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -13,10 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/pybind/imperative.h"
+
+#include <pybind11/chrono.h>
+#include <pybind11/complex.h>
+#include <pybind11/functional.h>
+#include <pybind11/stl.h>
+
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
 
+#include "paddle/fluid/pybind/pybind_boost_headers.h"
+
 namespace paddle {
 namespace pybind {
 
@@ -31,20 +39,20 @@ void BindTracer(pybind11::module* m) {
            [](imperative::Tracer& self, imperative::OpBase* op,
               const imperative::VarBasePtrMap& inputs,
               const imperative::VarBasePtrMap& outputs,
-              framework::BlockDesc* block,
+              framework::AttributeMap attrs_map,
               const platform::CPUPlace expected_place,
               const bool stop_gradient = false) {
-             return self.Trace(op, inputs, outputs, block, expected_place,
+             return self.Trace(op, inputs, outputs, attrs_map, expected_place,
                                stop_gradient);
            })
       .def("trace",
            [](imperative::Tracer& self, imperative::OpBase* op,
               const imperative::VarBasePtrMap& inputs,
               const imperative::VarBasePtrMap& outputs,
-              framework::BlockDesc* block,
+              framework::AttributeMap attrs_map,
               const platform::CUDAPlace expected_place,
               const bool stop_gradient = false) {
-             return self.Trace(op, inputs, outputs, block, expected_place,
+             return self.Trace(op, inputs, outputs, attrs_map, expected_place,
                                stop_gradient);
            })
       .def("py_trace", &imperative::Tracer::PyTrace,
diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h
index 8c48b2a715..8496cbfcb1 100644
--- a/paddle/fluid/pybind/imperative.h
+++ b/paddle/fluid/pybind/imperative.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include <Python.h>
+#include <string>
 #include <vector>
 #include "paddle/fluid/imperative/layer.h"
 #include "pybind11/pybind11.h"
@@ -36,6 +37,8 @@ class Layer : public imperative::Layer {
 class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
  public:
   using imperative::OpBase::OpBase;  // Inherit constructors
+
+  PyOpBase(const std::string& name) : OpBase(name) {}
 };
 
 class PyVarBase : public imperative::VarBase {
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index e729be4a95..7b5e417504 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -23,97 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 
-// Cast boost::variant for PyBind.
-// Copy from
-// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
-namespace pybind11 {
-namespace detail {
-
-#if !defined(PYBIND11_HIDDEN)
-#ifdef _WIN32
-#define PYBIND11_HIDDEN __declspec(dllexport)
-#else
-#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
-#endif
-#endif
-
-// Can be replaced by a generic lambda in C++14
-struct PYBIND11_HIDDEN paddle_variant_caster_visitor
-    : public boost::static_visitor<handle> {
-  return_value_policy policy;
-  handle parent;
-
-  paddle_variant_caster_visitor(return_value_policy policy, handle parent)
-      : policy(policy), parent(parent) {}
-
-  template <class T>
-  handle operator()(T const &src) const {
-    return make_caster<T>::cast(src, policy, parent);
-  }
-};
-
-template <class Variant>
-struct paddle_variant_caster;
-
-template <template <class...> class V, class... Ts>
-struct paddle_variant_caster<V<Ts...>> {
-  using Type = V<Ts...>;
-
-  template <typename T>
-  typename std::enable_if<
-      !std::is_same<T, boost::detail::variant::void_>::value, bool>::type
-  try_load(handle src, bool convert) {
-    auto caster = make_caster<T>();
-    if (!load_success_ && caster.load(src, convert)) {
-      load_success_ = true;
-
-      if (std::is_same<T, std::vector<float>>::value) {
-        auto caster_ints = make_caster<std::vector<int64_t>>();
-        if (caster_ints.load(src, convert)) {
-          VLOG(4) << "This value are floats and int64_ts satisfy "
-                     "simultaneously, will set it's type to "
-                     "std::vector<int64_t>";
-          value = cast_op<std::vector<int64_t>>(caster_ints);
-          return true;
-        }
-      }
-
-      value = cast_op<T>(caster);
-      return true;
-    }
-    return false;
-  }
-
-  template <typename T>
-  typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
-                          bool>::type
-  try_load(handle src, bool convert) {
-    return false;
-  }
-
-  bool load(handle src, bool convert) {
-    auto unused = {false, try_load<Ts>(src, convert)...};
-    (void)(unused);
-    return load_success_;
-  }
-
-  static handle cast(Type const &src, return_value_policy policy,
-                     handle parent) {
-    paddle_variant_caster_visitor visitor(policy, parent);
-    return boost::apply_visitor(visitor, src);
-  }
-
-  PYBIND11_TYPE_CASTER(Type, _("Variant"));
-  bool load_success_{false};
-};
-
-// Add specialization for concrete variant type
-template <class... Args>
-struct type_caster<boost::variant<Args...>>
-    : paddle_variant_caster<boost::variant<Args...>> {};
-
-}  // namespace detail
-}  // namespace pybind11
+#include "paddle/fluid/pybind/pybind_boost_headers.h"
 
 namespace paddle {
 namespace pybind {
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index cf59ff6d3b..395093a1f5 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -149,8 +149,14 @@ PYBIND11_MODULE(core, m) {
         []() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); });
 
   py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
-      // .def(py::init<>())
-      .def(py::init<bool>(), py::arg("stop_gradient") = false)
+      .def(
+          py::init<const std::string &, paddle::framework::proto::VarType::Type,
+                   const std::vector<int64_t>, const paddle::platform::CPUPlace,
+                   bool, bool>())
+      .def(
+          py::init<const std::string &, paddle::framework::proto::VarType::Type,
+                   const std::vector<int64_t>,
+                   const paddle::platform::CUDAPlace, bool, bool>())
       .def("_run_backward",
            [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
@@ -177,51 +183,21 @@ PYBIND11_MODULE(core, m) {
            py::return_value_policy::take_ownership)
       .def("value", [](const imperative::VarBase &self) { return self.var_; },
            py::return_value_policy::reference)
-      .def_property("name",
-                    [](const imperative::VarBase &self) { return self.name_; },
-                    [](imperative::VarBase &self, const std::string &name) {
-                      self.name_ = name;
-                    })
-      .def_property("block",
-                    [](const imperative::VarBase &self) { return self.block_; },
-                    [](imperative::VarBase &self, framework::BlockDesc *block) {
-                      self.block_ = block;
-                    },
-                    py::return_value_policy::reference)
-      .def_property(
-          "persistable",
-          [](const imperative::VarBase &self) { return self.persistable_; },
-          [](imperative::VarBase &self, const bool persistable) {
-            self.persistable_ = persistable;
-          })
-      .def_property(
-          "desc",
-          [](const imperative::VarBase &self) { return self.var_desc_; },
-          [](imperative::VarBase &self, framework::VarDesc *var_desc) {
-            self.var_desc_ = var_desc;
-          },
-          py::return_value_policy::reference)
-      .def_property(
-          "stop_gradient",
-          [](const imperative::VarBase &self) { return self.IsStopGradient(); },
-          [](imperative::VarBase &self, bool stop_gradient) {
-            self.SetStopGradient(stop_gradient);
-          });
+      .def_property("name", &imperative::VarBase::Name,
+                    &imperative::VarBase::SetName)
+      .def_property_readonly("shape", &imperative::VarBase::Shape)
+      .def_property_readonly("dtype", &imperative::VarBase::DType)
+      .def_property("persistable", &imperative::VarBase::IsPersistable,
+                    &imperative::VarBase::SetPersistable)
+      .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
+                    &imperative::VarBase::SetStopGradient);
 
   py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
-      .def(py::init<>())
+      .def(py::init<const std::string &>())
       .def("register_backward_hooks",
            [](imperative::OpBase &self, const py::object &callable) {
              self.RegisterBackwardHooks(callable);
            })
-      .def_property(
-          "desc", [](const imperative::OpBase &self) { return self.op_desc_; },
-          [](imperative::OpBase &self, framework::OpDesc *op_desc) {
-            if (op_desc) {
-              self.op_desc_ = op_desc;
-            }
-          },
-          py::return_value_policy::reference)
       .def_property("_trace_id",
                     [](const imperative::OpBase &self) {
                       pybind11::gil_scoped_release release;
@@ -260,7 +236,17 @@ PYBIND11_MODULE(core, m) {
           "apply",
           [](int func_id, const std::vector<imperative::VarBase *> &inputs)
               -> std::vector<imperative::VarBase *> {
-                return imperative::PyLayer::Apply(func_id, inputs);
+                auto ret_vars = imperative::PyLayer::Apply(func_id, inputs);
+                std::vector<imperative::VarBase *> outputs;
+                outputs.reserve(ret_vars.size());
+                for (size_t i = 0U; i != ret_vars.size(); ++i) {
+                  framework::Variable *v = ret_vars[i];
+                  // TODO(minqiyang): use unique_name generator to set a name
+                  outputs.emplace_back(
+                      new imperative::VarBase("", v, nullptr, true));
+                }
+
+                return outputs;
               },
           py::return_value_policy::take_ownership)
       .def_static("register_func",
@@ -876,9 +862,11 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<const platform::Place &>())
       .def("close", &Executor::Close)
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
-                     int block_id, bool create_local_scope, bool create_vars) {
+                     int block_id, bool create_local_scope, bool create_vars,
+                     const std::vector<std::string> &fetch_vars) {
         pybind11::gil_scoped_release release;
-        self.Run(prog, scope, block_id, create_local_scope, create_vars);
+        self.Run(prog, scope, block_id, create_local_scope, create_vars,
+                 fetch_vars);
       });
 
   m.def("init_gflags", framework::InitGflags);
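
The pybind changes above replace desc-based construction of VarBase and OpBase with explicit constructor arguments and expose name/shape/dtype/persistable/stop_gradient as properties backed by the C++ accessors. A minimal Python-side sketch of what the new bindings accept (the variable names and values are illustrative, not part of the patch):

import paddle.fluid.core as core

# VarBase(name, proto dtype, shape, place, persistable, stop_gradient)
var = core.VarBase("x", core.VarDesc.VarType.FP32, [2, 3],
                   core.CPUPlace(), False, True)

# OpBase now takes the op type string up front instead of receiving an
# OpDesc afterwards.
op = core.OpBase("mul")
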
diff --git a/paddle/fluid/pybind/pybind_boost_headers.h b/paddle/fluid/pybind/pybind_boost_headers.h
new file mode 100644
index 0000000000..70c3136d09
--- /dev/null
+++ b/paddle/fluid/pybind/pybind_boost_headers.h
@@ -0,0 +1,115 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+#include <Python.h>
+
+#include <vector>
+
+#include "glog/logging.h"
+#include "paddle/fluid/platform/variant.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+// Cast boost::variant for PyBind.
+// Copy from
+// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
+namespace pybind11 {
+namespace detail {
+
+#if !defined(PYBIND11_HIDDEN)
+#ifdef _WIN32
+#define PYBIND11_HIDDEN __declspec(dllexport)
+#else
+#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
+#endif
+#endif
+
+// Can be replaced by a generic lambda in C++14
+struct PYBIND11_HIDDEN paddle_variant_caster_visitor
+    : public boost::static_visitor<handle> {
+  return_value_policy policy;
+  handle parent;
+
+  paddle_variant_caster_visitor(return_value_policy policy, handle parent)
+      : policy(policy), parent(parent) {}
+
+  template <class T>
+  handle operator()(T const &src) const {
+    return make_caster<T>::cast(src, policy, parent);
+  }
+};
+
+template <class Variant>
+struct paddle_variant_caster;
+
+template <template <class...> class V, class... Ts>
+struct paddle_variant_caster<V<Ts...>> {
+  using Type = V<Ts...>;
+
+  template <typename T>
+  typename std::enable_if<
+      !std::is_same<T, boost::detail::variant::void_>::value, bool>::type
+  try_load(handle src, bool convert) {
+    auto caster = make_caster<T>();
+    if (!load_success_ && caster.load(src, convert)) {
+      load_success_ = true;
+
+      if (std::is_same<T, std::vector<float>>::value) {
+        auto caster_ints = make_caster<std::vector<int64_t>>();
+        if (caster_ints.load(src, convert)) {
+          VLOG(4) << "This value satisfies both the float and int64_t "
+                     "casters; setting its type to "
+                     "std::vector<int64_t>";
+          value = cast_op<std::vector<int64_t>>(caster_ints);
+          return true;
+        }
+      }
+
+      value = cast_op<T>(caster);
+      return true;
+    }
+    return false;
+  }
+
+  template <typename T>
+  typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
+                          bool>::type
+  try_load(handle src, bool convert) {
+    return false;
+  }
+
+  bool load(handle src, bool convert) {
+    auto unused = {false, try_load<Ts>(src, convert)...};
+    (void)(unused);
+    return load_success_;
+  }
+
+  static handle cast(Type const &src, return_value_policy policy,
+                     handle parent) {
+    paddle_variant_caster_visitor visitor(policy, parent);
+    return boost::apply_visitor(visitor, src);
+  }
+
+  PYBIND11_TYPE_CASTER(Type, _("Variant"));
+  bool load_success_{false};
+};
+
+// Add specialization for concrete variant type
+template <class... Args>
+struct type_caster<boost::variant<Args...>>
+    : paddle_variant_caster<boost::variant<Args...>> {};
+
+}  // namespace detail
+}  // namespace pybind11
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 8102732c55..103c4d3dd0 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -128,11 +128,11 @@ def __bootstrap__():
         'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph',
         'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
         'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
-        'fast_eager_deletion_mode', 'allocator_strategy',
-        'reader_queue_speed_test_mode', 'print_sub_graph_dir',
-        'pe_profile_fname', 'warpctc_dir', 'inner_op_parallelism',
-        'enable_parallel_graph', 'multiple_of_cupti_buffer_size',
-        'enable_subgraph_optimize'
+        'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
+        'allocator_strategy', 'reader_queue_speed_test_mode',
+        'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
+        'inner_op_parallelism', 'enable_parallel_graph',
+        'multiple_of_cupti_buffer_size', 'enable_subgraph_optimize'
     ]
     if 'Darwin' not in sysstr:
         read_env_flags.append('use_pinned_memory')
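
The new memory_fraction_of_eager_deletion flag is registered here as a readable environment flag. Since __bootstrap__ reads these values when paddle.fluid is first imported, the flag has to be set in the environment before the import, which is exactly what the new partial-eager-deletion test below does. A minimal sketch:

import os
# Set FLAGS_* before the first `import paddle.fluid`, as the new tests do.
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['FLAGS_memory_fraction_of_eager_deletion'] = '0.55'

import paddle.fluid as fluid
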
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index dfa50e721c..cc3c0dd689 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -590,7 +590,7 @@ class Executor(object):
                 fetch_var_name=fetch_var_name)
 
         self._feed_data(program, feed, feed_var_name, scope)
-        exe.run(program.desc, scope, 0, True, True)
+        exe.run(program.desc, scope, 0, True, True, fetch_var_name)
         outs = self._fetch_data(fetch_list, fetch_var_name, scope)
         if return_numpy:
             outs = as_numpy(outs)
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 7dc9178807..5b9dd86931 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -304,96 +304,101 @@ class Variable(object):
                  is_data=False,
                  **kwargs):
         self.block = block
-        self.error_clip = error_clip
-
         if name is None:
             name = unique_name.generate('_generated_var')
-        is_new_var = False
-        name = cpt.to_text(name)
-        self.desc = self.block.desc.find_var(cpt.to_bytes(name))
 
-        if self.desc is None:
-            self.desc = self.block.desc.var(cpt.to_bytes(name))
-            is_new_var = True
-
-        if is_new_var:
-            self.desc.set_type(type)
-        elif self.desc.type() != type:
-            raise ValueError("Variable {0} has been created before. The "
-                             "previous type is {1}; the new type is {2}. They"
-                             " are not matched".format(self.name,
-                                                       self.desc.type(), type))
-
-        if shape is not None:
-            if is_new_var:
-                self.desc.set_shape(shape)
-            else:
-                old_shape = self.shape
-                shape = tuple(shape)
-                if shape != old_shape:
-                    raise ValueError(
-                        "Variable {0} has been created before. the previous "
-                        "shape is {1}; the new shape is {2}. They are not "
-                        "matched.".format(self.name, old_shape, shape))
         if dtype is not None:
             if not isinstance(dtype, core.VarDesc.VarType):
                 dtype = convert_np_dtype_to_dtype_(dtype)
-            if is_new_var:
-                self.desc.set_dtype(dtype)
-            else:
-                old_dtype = self.dtype
-                if dtype != old_dtype:
-                    raise ValueError("Variable {0} has been created before. "
-                                     "The previous data type is {1}; the new "
-                                     "data type is {2}. They are not "
-                                     "matched.".format(self.name, old_dtype,
-                                                       dtype))
-
-        if lod_level is not None:
-            if is_new_var:
-                self.desc.set_lod_level(lod_level)
-            else:
-                if lod_level != self.lod_level:
-                    raise ValueError("Variable {0} has been created before. "
-                                     "The previous lod_level is {1}; the new "
-                                     "lod_level is {2}. They are not "
-                                     "matched".format(self.name, self.lod_level,
-                                                      lod_level))
-        if persistable is not None:
-            if is_new_var:
-                self.desc.set_persistable(persistable)
-            else:
-                if persistable != self.persistable:
-                    raise ValueError(
-                        "Variable {0} has been created before."
-                        "The previous persistable is {1}; the new "
-                        "persistable is {2}. They are not matched".format(
-                            self.name, self.persistable, persistable))
-
-        if capacity is not None:
-            if is_new_var:
-                self.desc.set_capacity(capacity)
-            else:
-                # TODO(abhinavarora) : Compare with set capacity once,
-                # get_capacity is implemented
-                pass
 
         if _in_imperative_mode():
             # record vars in tracer rather than blocks
             self._ivar = kwargs.get("ivar", None)
             if not self._ivar:
-                self._ivar = core.VarBase(stop_gradient)
-            self._ivar.desc = self.desc
-            self._ivar.block = block.desc
-            self._ivar.name = name
-            self._ivar.persistable = persistable
+                self._ivar = core.VarBase(
+                    name, dtype if dtype else core.VarDesc.VarType.FP32,
+                    list(shape) if shape else [],
+                    _current_expected_place(), True
+                    if persistable else False, stop_gradient)
             if persistable:
-                self.block.vars[name] = self
+                _imperative_tracer().trace_var(name, self)
         else:
+            self.error_clip = error_clip
+
+            is_new_var = False
+            name = cpt.to_text(name)
+            self.desc = self.block.desc.find_var(cpt.to_bytes(name))
+
+            if self.desc is None:
+                self.desc = self.block.desc.var(cpt.to_bytes(name))
+                is_new_var = True
+
+            if is_new_var:
+                self.desc.set_type(type)
+            elif self.desc.type() != type:
+                raise ValueError(
+                    "Variable {0} has been created before. The "
+                    "previous type is {1}; the new type is {2}. They"
+                    " are not matched".format(self.name, self.desc.type(),
+                                              type))
+
+            if shape is not None:
+                if is_new_var:
+                    self.desc.set_shape(shape)
+                else:
+                    old_shape = self.shape
+                    shape = tuple(shape)
+                    if shape != old_shape:
+                        raise ValueError(
+                            "Variable {0} has been created before. the previous "
+                            "shape is {1}; the new shape is {2}. They are not "
+                            "matched.".format(self.name, old_shape, shape))
+            if dtype is not None:
+                if is_new_var:
+                    self.desc.set_dtype(dtype)
+                else:
+                    old_dtype = self.dtype
+                    if dtype != old_dtype:
+                        raise ValueError(
+                            "Variable {0} has been created before. "
+                            "The previous data type is {1}; the new "
+                            "data type is {2}. They are not "
+                            "matched.".format(self.name, old_dtype, dtype))
+
+            if lod_level is not None:
+                if is_new_var:
+                    self.desc.set_lod_level(lod_level)
+                else:
+                    if lod_level != self.lod_level:
+                        raise ValueError(
+                            "Variable {0} has been created before. "
+                            "The previous lod_level is {1}; the new "
+                            "lod_level is {2}. They are not "
+                            "matched".format(self.name, self.lod_level,
+                                             lod_level))
+            if persistable is not None:
+                if is_new_var:
+                    self.desc.set_persistable(persistable)
+                else:
+                    if persistable != self.persistable:
+                        raise ValueError(
+                            "Variable {0} has been created before."
+                            "The previous persistable is {1}; the new "
+                            "persistable is {2}. They are not matched".format(
+                                self.name, self.persistable, persistable))
+
+            if capacity is not None:
+                if is_new_var:
+                    self.desc.set_capacity(capacity)
+                else:
+                    # TODO(abhinavarora) : Compare with set capacity once,
+                    # get_capacity is implemented
+                    pass
+
             self.block.vars[name] = self
-        self.op = None
-        self.stop_gradient = stop_gradient
-        self.is_data = is_data
+            self.op = None
+            self.stop_gradient = stop_gradient
+            self.is_data = is_data
 
     def _numpy(self):
         new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
@@ -462,40 +467,63 @@ class Variable(object):
     def _stop_gradient(self, s):
         if _in_imperative_mode():
             self._ivar.stop_gradient = s
-        self.stop_gradient = s
+        else:
+            self.stop_gradient = s
 
     @property
     def persistable(self):
-        return self.desc.persistable()
+        if _in_imperative_mode():
+            return self._ivar.persistable
+        else:
+            return self.desc.persistable()
 
     @persistable.setter
     def persistable(self, p):
-        self.desc.set_persistable(p)
+        if _in_imperative_mode():
+            self._ivar.persistable = p
+        else:
+            self.desc.set_persistable(p)
 
     @property
     def name(self):
-        return cpt.to_text(self.desc.name())
+        if _in_imperative_mode():
+            return self._ivar.name
+        else:
+            return cpt.to_text(self.desc.name())
 
     @name.setter
     def name(self, new_name):
-        self.desc.set_name(new_name)
+        if _in_imperative_mode():
+            self._ivar.name = new_name
+        else:
+            self.desc.set_name(new_name)
 
     @property
     def shape(self):
         # convert to tuple, make it as same as numpy API.
-        return tuple(self.desc.shape())
+        if _in_imperative_mode():
+            return self._ivar.shape
+        else:
+            return tuple(self.desc.shape())
 
     @property
     def dtype(self):
-        return self.desc.dtype()
+        if _in_imperative_mode():
+            return self._ivar.dtype
+        else:
+            return self.desc.dtype()
 
     @property
     def lod_level(self):
+        # TODO(minqiyang): Support lod_level in imperative mode
         return self.desc.lod_level()
 
     @property
     def type(self):
-        return self.desc.type()
+        if _in_imperative_mode():
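+            # Note: the new VarBase bindings expose dtype but not the
+            # VarDesc VarType, so dtype is returned here in imperative mode.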
+            return self._ivar.dtype
+        else:
+            return self.desc.type()
 
     def _set_error_clip(self, error_clip):
         """
@@ -624,121 +652,14 @@ class Operator(object):
                  inputs=None,
                  outputs=None,
                  attrs=None):
-        self.block = block
-        self.desc = desc
-        # note: not add self.attrs here:
-        # https://github.com/PaddlePaddle/Paddle/pull/12583#pullrequestreview-145093173
-        op_attrs = attrs
-        if op_attrs is None:
-            op_attrs = dict()
-        del attrs
-
-        op_maker = core.op_proto_and_checker_maker
-
-        if op_maker.kOpRoleAttrName() not in op_attrs:
-            op_attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
-
-        role_var_name = op_maker.kOpRoleVarAttrName()
-        if len(self.block.program.
-               op_role_var) != 0 and role_var_name not in op_attrs:
-            op_attrs[role_var_name] = self.block.program.op_role_var
-
-        if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
-            del op_attrs[role_var_name]
-
-        if len(self.desc.type()) != 0:
-            return
-        if type is None:
-            raise ValueError(
-                "`type` to initilized an Operator can not be None.")
-        else:
-            callstack_var_name = op_maker.kOpCreationCallstackAttrName()
-            op_attrs[callstack_var_name] = list(
-                reversed(traceback.format_stack()))[1:]
-
-        self.desc.set_type(type)
-        proto = OpProtoHolder.instance().get_op_proto(type)
-
-        namescope_var_name = op_maker.kOpNameScopeAttrName()
-        op_attrs[namescope_var_name] = _full_name_scope()
-
-        def find_name(var_list, name):
-            for var_name in var_list:
-                if var_list[var_name] is not None and var_name == name:
-                    return True
-            return False
-
-        if inputs is not None:
-            for in_proto in proto.inputs:
-                found = find_name(inputs, in_proto.name)
-                assert found or in_proto.dispensable, "Input {} not found".format(
-                    in_proto.name)
-
-                if found:
-                    in_args = inputs[in_proto.name]
-                    if not isinstance(in_args, list):
-                        in_args = [in_args]
-                    if not in_proto.duplicable and len(in_args) > 1:
-                        raise ValueError(
-                            "Input %s expects only one input, but %d are given."
-                            % (in_proto.name, len(in_args)))
-                    in_arg_names = []
-                    for arg in in_args:
-                        if isinstance(arg, six.string_types):
-                            in_arg_names.append(arg)
-                        elif isinstance(arg, six.binary_type):
-                            in_arg_names.append(arg.decode())
-                        else:
-                            in_arg_names.append(cpt.to_text(arg.name))
-                    self.desc.set_input(in_proto.name, in_arg_names)
-                else:
-                    self.desc.set_input(in_proto.name, [])
-
-        if outputs is not None:
-            for m in proto.outputs:
-                if (m.name not in outputs) and m.dispensable:
-                    continue
-                if not ((m.name in outputs) or m.dispensable):
-                    raise ValueError(
-                        ("Incorrect setting for output(s) of "
-                         "operator \"%s\", should set: [%s].") % (type, m.name))
-            for out_proto in proto.outputs:
-                if out_proto.name not in outputs:
-                    continue
-                out_args = outputs[out_proto.name]
-                if not isinstance(out_args, list):
-                    out_args = [out_args]
-                if not out_proto.duplicable and len(out_args) > 1:
-                    raise ValueError(
-                        "Output %s expects only one output, but %d are given." %
-                        (out_proto.name, len(out_args)))
-                out_arg_names = []
-                for arg in out_args:
-                    out_arg_names.append(cpt.to_text(arg.name))
-                    # TODO(minqiyang): could we remove variable's op in static mode?
-                    if not _in_imperative_mode():
-                        arg.op = self
-                self.desc.set_output(out_proto.name, out_arg_names)
-
-        if op_attrs is not None:
-            if not isinstance(op_attrs, dict):
-                raise TypeError("'attrs' should be a dict.")
-            for attr in proto.attrs:
-                attr_name = attr.name
-                if (attr_name not in op_attrs) or (op_attrs[attr_name] is None):
-                    continue
-                attr_val = op_attrs[attr_name]
-                self._update_desc_attr(attr_name, attr_val)
-
-        self.desc.check_attrs()
-        if self._has_kernel(type):
-            self.desc.infer_var_type(self.block.desc)
-            self.desc.infer_shape(self.block.desc)
-
         if _in_imperative_mode():
-            self.iop = core.OpBase()
-            self.iop.desc = self.desc
+            if type is None:
+                raise ValueError(
+                    "`type` to initialize an Operator cannot be None.")
+            self.iop = core.OpBase(type)
 
+            # TODO(minqiyang): remove these lines after we take apart all
+            # backward grads and forward variables
             self.inputs = defaultdict(list)
             if inputs is not None:
                 for k, v in six.iteritems(inputs):
@@ -755,6 +676,121 @@ class Operator(object):
                     elif isinstance(v, list) or isinstance(v, tuple):
                         self.outputs[k].extend([var._ivar for var in v])
 
+            self.attrs = attrs if attrs else {}
+        else:
+            self.block = block
+            self.desc = desc
+            # note: not add self.attrs here:
+            # https://github.com/PaddlePaddle/Paddle/pull/12583#pullrequestreview-145093173
+            op_attrs = attrs
+            if op_attrs is None:
+                op_attrs = dict()
+            del attrs
+
+            op_maker = core.op_proto_and_checker_maker
+
+            if op_maker.kOpRoleAttrName() not in op_attrs:
+                op_attrs[op_maker.kOpRoleAttrName(
+                )] = self.block.program.op_role
+
+            role_var_name = op_maker.kOpRoleVarAttrName()
+            if len(self.block.program.
+                   op_role_var) != 0 and role_var_name not in op_attrs:
+                op_attrs[role_var_name] = self.block.program.op_role_var
+
+            if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
+                del op_attrs[role_var_name]
+
+            if len(self.desc.type()) != 0:
+                return
+            if type is None:
+                raise ValueError(
+                    "`type` to initialize an Operator cannot be None.")
+            else:
+                callstack_var_name = op_maker.kOpCreationCallstackAttrName()
+                op_attrs[callstack_var_name] = list(
+                    reversed(traceback.format_stack()))[1:]
+
+            self.desc.set_type(type)
+            proto = OpProtoHolder.instance().get_op_proto(type)
+
+            namescope_var_name = op_maker.kOpNameScopeAttrName()
+            op_attrs[namescope_var_name] = _full_name_scope()
+
+            def find_name(var_list, name):
+                for var_name in var_list:
+                    if var_list[var_name] is not None and var_name == name:
+                        return True
+                return False
+
+            if inputs is not None:
+                for in_proto in proto.inputs:
+                    found = find_name(inputs, in_proto.name)
+                    assert found or in_proto.dispensable, "Input {} not found".format(
+                        in_proto.name)
+
+                    if found:
+                        in_args = inputs[in_proto.name]
+                        if not isinstance(in_args, list):
+                            in_args = [in_args]
+                        if not in_proto.duplicable and len(in_args) > 1:
+                            raise ValueError(
+                                "Input %s expects only one input, but %d are given."
+                                % (in_proto.name, len(in_args)))
+                        in_arg_names = []
+                        for arg in in_args:
+                            if isinstance(arg, six.string_types):
+                                in_arg_names.append(arg)
+                            elif isinstance(arg, six.binary_type):
+                                in_arg_names.append(arg.decode())
+                            else:
+                                in_arg_names.append(cpt.to_text(arg.name))
+                        self.desc.set_input(in_proto.name, in_arg_names)
+                    else:
+                        self.desc.set_input(in_proto.name, [])
+
+            if outputs is not None:
+                for m in proto.outputs:
+                    if (m.name not in outputs) and m.dispensable:
+                        continue
+                    if not ((m.name in outputs) or m.dispensable):
+                        raise ValueError(("Incorrect setting for output(s) of "
+                                          "operator \"%s\", should set: [%s].")
+                                         % (type, m.name))
+                for out_proto in proto.outputs:
+                    if out_proto.name not in outputs:
+                        continue
+                    out_args = outputs[out_proto.name]
+                    if not isinstance(out_args, list):
+                        out_args = [out_args]
+                    if not out_proto.duplicable and len(out_args) > 1:
+                        raise ValueError(
+                            "Output %s expects only one output, but %d are given."
+                            % (out_proto.name, len(out_args)))
+                    out_arg_names = []
+                    for arg in out_args:
+                        out_arg_names.append(cpt.to_text(arg.name))
+                        # TODO(minqiyang): could we remove variable's op in static mode?
+                        if not _in_imperative_mode():
+                            arg.op = self
+                    self.desc.set_output(out_proto.name, out_arg_names)
+
+            if op_attrs is not None:
+                if not isinstance(op_attrs, dict):
+                    raise TypeError("'attrs' should be a dict.")
+                for attr in proto.attrs:
+                    attr_name = attr.name
+                    if (attr_name not in op_attrs) or (
+                            op_attrs[attr_name] is None):
+                        continue
+                    attr_val = op_attrs[attr_name]
+                    self._update_desc_attr(attr_name, attr_val)
+
+            self.desc.check_attrs()
+            if self._has_kernel(type):
+                self.desc.infer_var_type(self.block.desc)
+                self.desc.infer_shape(self.block.desc)
+
     def _has_kernel(self, op_type):
         return op_type not in self.OP_WITHOUT_KERNEL_SET
 
@@ -1318,16 +1354,15 @@ class Block(object):
         Returns:
             Operator: the append Operator.
         """
-        op_desc = self.desc.append_op()
-        op = Operator(
-            block=self,
-            desc=op_desc,
-            type=kwargs.get("type", None),
-            inputs=kwargs.get("inputs", None),
-            outputs=kwargs.get("outputs", None),
-            attrs=kwargs.get("attrs", None))
-
         if _in_imperative_mode():
+            op = Operator(
+                block=self,
+                desc=None,
+                type=kwargs.get("type", None),
+                inputs=kwargs.get("inputs", None),
+                outputs=kwargs.get("outputs", None),
+                attrs=kwargs.get("attrs", None))
+
             # record ops in tracer rather than blocks
             #
             # TODO(minqiyang): add op stop_gradient support in static mode too.
@@ -1335,6 +1370,15 @@ class Block(object):
             _imperative_tracer().trace_op(op,
                                           kwargs.get("stop_gradient", False))
         else:
+            op_desc = self.desc.append_op()
+            op = Operator(
+                block=self,
+                desc=op_desc,
+                type=kwargs.get("type", None),
+                inputs=kwargs.get("inputs", None),
+                outputs=kwargs.get("outputs", None),
+                attrs=kwargs.get("attrs", None))
+
             self.ops.append(op)
 
         return op
@@ -1383,19 +1427,27 @@ class Block(object):
         return self.ops[start:end]
 
     def _prepend_op(self, *args, **kwargs):
-        op_desc = self.desc._prepend_op()
-        op = Operator(
-            self,
-            op_desc,
-            type=kwargs.get("type", None),
-            inputs=kwargs.get("inputs", None),
-            outputs=kwargs.get("outputs", None),
-            attrs=kwargs.get("attrs", None))
         if _in_imperative_mode():
+            op = Operator(
+                self,
+                None,
+                type=kwargs.get("type", None),
+                inputs=kwargs.get("inputs", None),
+                outputs=kwargs.get("outputs", None),
+                attrs=kwargs.get("attrs", None))
             _imperative_tracer().trace_op(op,
                                           kwargs.get("stop_gradient", False))
         else:
+            op_desc = self.desc._prepend_op()
+            op = Operator(
+                self,
+                op_desc,
+                type=kwargs.get("type", None),
+                inputs=kwargs.get("inputs", None),
+                outputs=kwargs.get("outputs", None),
+                attrs=kwargs.get("attrs", None))
             self.ops.insert(0, op)
+
         return op
 
     def _sync_with_cpp(self):
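
Variable, Operator and Block.append_op/_prepend_op now branch on _in_imperative_mode(): the desc-based path (with its type/shape/dtype/lod_level/persistable consistency checks) only runs in static mode, while imperative objects wrap core.VarBase/core.OpBase and are handed to the tracer. A rough sketch of the call shape both paths accept, reusing the reduce_sum example from test_imperative_basic.py (variable names are illustrative):

block.append_op(
    type='reduce_sum',
    inputs={'X': softmax_out},
    outputs={'Out': reduce_out},
    attrs={'dim': [],
           'keep_dim': False,
           'reduce_all': True},
    stop_gradient=False)
# Static mode: an OpDesc is appended to block.desc and the Operator is kept
# in block.ops. Imperative mode: the Operator wraps core.OpBase(type), keeps
# attrs as a plain dict, and is passed straight to
# _imperative_tracer().trace_op(op, stop_gradient).
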
diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py
index 0c96d4dc59..71d169a7dc 100644
--- a/python/paddle/fluid/imperative/layers.py
+++ b/python/paddle/fluid/imperative/layers.py
@@ -258,7 +258,7 @@ class PyLayer(core.PyLayer):
             cls.backward_id = core.PyLayer.num_funcs() + 1
             PyLayer.register_func(cls.backward_id, cls._do_backward)
 
-        iop = core.OpBase()
+        iop = core.OpBase(cls.__name__ + str(cls.forward_id))
         iop.forward_id = cls.forward_id
         iop.backward_id = cls.backward_id
         block.ops.append(iop)
diff --git a/python/paddle/fluid/imperative/tracer.py b/python/paddle/fluid/imperative/tracer.py
index 1064ad63e7..bd77de7424 100644
--- a/python/paddle/fluid/imperative/tracer.py
+++ b/python/paddle/fluid/imperative/tracer.py
@@ -36,14 +36,21 @@ class Tracer(core.Tracer):
         super(Tracer, self).__init__(block)
 
         self._ops = defaultdict()
+        self._vars = defaultdict()
         self._trace_id = 0
 
+    def trace_var(self, name, var):
+        self._vars[name] = var
+
+    def all_parameters(self):
+        return [item for name, item in six.iteritems(self._vars)
+                if isinstance(item, framework.Parameter)]
+
     def trace_op(self, op, stop_gradient=False):
         # record op's trace id
         op.iop._trace_id = self._trace_id
 
-        # trace op and save it
-        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc,
+        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
                                    framework._current_expected_place(),
                                    stop_gradient)
 
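
The tracer now doubles as the imperative-mode parameter registry: persistable Variables register themselves through trace_var() in framework.py above, trace_op() forwards the op's plain attrs dict instead of a BlockDesc, and the optimizer change below reads parameters back through all_parameters() rather than walking program.global_block(). A minimal sketch of the lookup side (assumes imperative mode is active so a tracer has been installed):

from paddle.fluid import framework

tracer = framework._imperative_tracer()
trainable_params = [p for p in tracer.all_parameters() if p.trainable]
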
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 9d1d5fe093..d0bff52e43 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -10704,8 +10704,9 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
 
     similarity_matrix = matmul(
         anchor, positive, transpose_x=False, transpose_y=True)
-    softmax_value = softmax(similarity_matrix)
-    cross_entropy = -1 * reduce_sum(labels * log(softmax_value), 0)
+    softmax_ce = softmax_with_cross_entropy(
+        logits=similarity_matrix, label=labels, soft_label=True)
+    cross_entropy = reduce_sum(labels * softmax_ce, 0)
     celoss = reduce_mean(cross_entropy)
 
     return l2loss + celoss
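
npair_loss now relies on the fused softmax_with_cross_entropy kernel with soft_label=True instead of composing softmax, log and reduce_sum by hand; the fused form works in log-space and sidesteps the log(0)/overflow problems the composed form can hit with large logits. A NumPy illustration of the underlying numerics (a sketch of the math only, not the Paddle kernel):

import numpy as np

logits = np.array([1000.0, 0.0, -1000.0])
labels = np.array([1.0, 0.0, 0.0])

# Naive: exp() overflows, softmax becomes nan, and log() propagates it.
p = np.exp(logits) / np.sum(np.exp(logits))
naive_ce = -np.sum(labels * np.log(p))              # nan

# Stable: stay in log-space, shifting by the max logit first.
shifted = logits - logits.max()
log_softmax = shifted - np.log(np.sum(np.exp(shifted)))
stable_ce = -np.sum(labels * log_softmax)           # finite (~0.0 here)
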
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 86b7716664..d501d02bd4 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -377,17 +377,16 @@ class Optimizer(object):
             and list of (param, grad) Variables pair for optimization.
         """
         self._dtype = loss.dtype
-        program = loss.block.program
         optimize_ops = []
         if framework._in_imperative_mode():
             if parameter_list is not None:
                 parameters = parameter_list
             else:
-                parameters = program.global_block().all_parameters()
+                parameters = framework._imperative_tracer().all_parameters()
 
             params_grads = []
             for param in parameters:
-                if param.stop_gradient or not param.trainable:
+                if not param.trainable:
                     continue
                 # create gradient variable
                 grad_var = Variable(
@@ -396,9 +395,11 @@ class Optimizer(object):
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
                 params_grads.append((param, grad_var))
-            with program_guard(program, startup_program):
+            with program_guard(framework.default_main_program(),
+                               framework.default_startup_program()):
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
+            program = loss.block.program
             with program_guard(program, startup_program):
                 params_grads = self.backward(loss, startup_program,
                                              parameter_list, no_grad_set)
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py
index 603c8e7488..05cc41b96f 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py
@@ -16,8 +16,7 @@ import os
 import unittest
 os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
 
-os.environ[
-    'RECORDIO_FILENAME'] = '/tmp/eager_deletion_transformer.wmt16.recordio'
+os.environ['RECORDIO_FILENAME'] = './eager_deletion_transformer.wmt16.recordio'
 
 from test_parallel_executor_transformer import TestTransformer
 
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py
new file mode 100644
index 0000000000..898d04ebe1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+os.environ['CPU_NUM'] = '2'
+os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
+os.environ['FLAGS_fast_eager_deletion_mode'] = '1'
+
+import unittest
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+from paddle.fluid.executor import Executor
+import paddle.fluid.core as core
+from paddle.fluid.backward import append_backward
+import paddle.fluid.compiler as compiler
+import numpy
+import multiprocessing
+
+
+class TestEagerDeletionWhileOpBase(unittest.TestCase):
+    def test_main(self):
+        places = [core.CPUPlace(), ]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+
+        for p in places:
+            for with_data_parallel in [False, True]:
+                with fluid.program_guard(fluid.Program(), fluid.Program()):
+                    with fluid.scope_guard(fluid.Scope()):
+                        self.run_main(p, with_data_parallel)
+
+    def run_main(self, place, with_data_parallel):
+        self.place = place
+        self.with_data_parallel = with_data_parallel
+
+        if not core.is_compiled_with_cuda() and isinstance(self.place,
+                                                           core.CUDAPlace):
+            return
+
+        if isinstance(self.place, core.CUDAPlace):
+            device_cnt = core.get_cuda_device_count(
+            ) if self.with_data_parallel else 1
+        else:
+            device_cnt = int(
+                os.environ.get('CPU_NUM', multiprocessing.cpu_count(
+                ))) if self.with_data_parallel else 1
+
+        d0 = layers.data(
+            "d0", shape=[10], append_batch_size=False, dtype='float32')
+        d1 = layers.data(
+            "d1", shape=[10], append_batch_size=False, dtype='float32')
+        d2 = layers.data(
+            "d2", shape=[10], append_batch_size=False, dtype='float32')
+
+        i = layers.zeros(shape=[1], dtype='int64')
+        i.stop_gradient = True
+
+        init = layers.zeros(shape=[10], dtype='float32')
+        mem_array = layers.array_write(x=init, i=i)
+        data_array = layers.array_write(x=d0, i=i)
+
+        i = layers.increment(i)
+        layers.array_write(d1, i, array=data_array)
+
+        i = layers.increment(i)
+        layers.array_write(d2, i, array=data_array)
+
+        i = layers.zeros(shape=[1], dtype='int64')
+        i.stop_gradient = True
+
+        array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
+        array_len.stop_gradient = True
+        cond = layers.less_than(x=i, y=array_len)
+
+        j = layers.fill_constant(shape=[1], dtype='int64', value=1)
+        j.stop_gradient = True
+
+        array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
+        array_len2.stop_gradient = True
+        cond2 = layers.less_than(x=j, y=array_len2)
+
+        while_op = layers.While(cond=cond)
+        while_op2 = layers.While(cond=cond2)
+        with while_op.block():
+            d = layers.array_read(array=data_array, i=i)
+            prev = layers.array_read(array=mem_array, i=i)
+            d = layers.reshape(d, shape=[10])
+            prev = layers.reshape(prev, shape=[10])
+            result = layers.sums(input=[d, prev])
+
+            i = layers.increment(x=i, in_place=True)
+            layers.array_write(result, i=i, array=mem_array)
+            layers.less_than(x=i, y=array_len, cond=cond)
+            with while_op2.block():
+                d2 = layers.array_read(array=data_array, i=j)
+                prev2 = layers.array_read(array=mem_array, i=j)
+                d2 = layers.reshape(d2, shape=[10])
+                prev2 = layers.reshape(prev2, shape=[10])
+                result2 = layers.sums(input=[d2, prev2])
+
+                j = layers.increment(x=j, in_place=True)
+                layers.array_write(result2, i=j, array=mem_array)
+                layers.less_than(x=j, y=array_len2, cond=cond2)
+
+        sum_result = layers.array_read(array=mem_array, i=j)
+        sum_result.persistable = True
+        tmp = layers.unsqueeze(sum_result, axes=[0])
+        tmp = layers.expand(tmp, expand_times=[10, 1])
+        fc = layers.fc(tmp, size=256)
+        loss = layers.mean(sum_result)
+
+        optim = fluid.optimizer.Adam(learning_rate=1e-3)
+        optim.minimize(loss)
+
+        exe = Executor(self.place)
+        exe.run(fluid.default_startup_program())
+
+        prog = compiler.CompiledProgram(fluid.default_main_program())
+        if self.with_data_parallel:
+            prog = prog.with_data_parallel()
+
+        for _ in range(5):
+            d = []
+            for i in range(3):
+                tmp = numpy.random.random(size=[10]).astype('float32')
+                if not self.with_data_parallel:
+                    d.append(tmp)
+                else:
+                    d.append(numpy.array([tmp] * device_cnt))
+
+            outs = exe.run(program=prog,
+                           feed={'d0': d[0],
+                                 'd1': d[1],
+                                 'd2': d[2]},
+                           fetch_list=[sum_result])
+            self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py
index 97fc1eab3d..4c44195a3d 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py
@@ -152,7 +152,7 @@ class SimpleRNNCell(fluid.imperative.Layer):
             type='reduce_sum',
             inputs={'X': softmax_out},
             outputs={'Out': reduce_out},
-            attrs={'dim': None,
+            attrs={'dim': [],
                    'keep_dim': False,
                    'reduce_all': True})
 
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
index 94ac393315..ab9298890b 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
@@ -277,7 +277,7 @@ class TestImperativeResnet(unittest.TestCase):
 
                 dy_grad_value = {}
                 for param in resnet.parameters():
-                    if not param.stop_gradient:
+                    if param.trainable:
                         np_array = np.array(param._ivar._grad_ivar().value()
                                             .get_tensor())
                         dy_grad_value[param.name + core.grad_var_suffix(
@@ -322,7 +322,7 @@ class TestImperativeResnet(unittest.TestCase):
             for param in resnet.parameters():
                 static_param_name_list.append(param.name)
             for param in resnet.parameters():
-                if not param.stop_gradient:
+                if param.trainable:
                     static_grad_name_list.append(param.name +
                                                  core.grad_var_suffix())
 
diff --git a/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py b/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py
new file mode 100644
index 0000000000..7607189454
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
+os.environ['FLAGS_memory_fraction_of_eager_deletion'] = "0.55"
+
+os.environ['RECORDIO_FILENAME'] = './p_gc_transformer.wmt16.recordio'
+
+from test_parallel_executor_transformer import TestTransformer
+
+if __name__ == '__main__':
+    unittest.main()