diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py
index 0d5c9652de..9540900b11 100644
--- a/benchmark/fluid/args.py
+++ b/benchmark/fluid/args.py
@@ -136,10 +136,6 @@ def parse_args():
         '--no_random',
         action='store_true',
         help='If set, keep the random seed and do not shuffle the data.')
-    parser.add_argument(
-        '--use_lars',
-        action='store_true',
-        help='If set, use lars for optimizers, ONLY support resnet module.')
     parser.add_argument(
         '--reduce_strategy',
         type=str,
diff --git a/benchmark/fluid/models/resnet.py b/benchmark/fluid/models/resnet.py
index 1b3bfe659c..f692e7722a 100644
--- a/benchmark/fluid/models/resnet.py
+++ b/benchmark/fluid/models/resnet.py
@@ -200,11 +200,6 @@ def get_model(args, is_train, main_prog, startup_prog):
             # configure optimize
             optimizer = None
             if is_train:
-                if args.use_lars:
-                    lars_decay = 1.0
-                else:
-                    lars_decay = 0.0
-
                 total_images = 1281167 / trainer_count
 
                 step = int(total_images / (args.batch_size * args.gpus) + 1)
diff --git a/benchmark/fluid/models/resnet_with_preprocess.py b/benchmark/fluid/models/resnet_with_preprocess.py
index e8d661d847..e996c9a704 100644
--- a/benchmark/fluid/models/resnet_with_preprocess.py
+++ b/benchmark/fluid/models/resnet_with_preprocess.py
@@ -224,11 +224,6 @@ def get_model(args, is_train, main_prog, startup_prog):
             # configure optimize
             optimizer = None
             if is_train:
-                if args.use_lars:
-                    lars_decay = 1.0
-                else:
-                    lars_decay = 0.0
-
                 total_images = 1281167 / trainer_count
 
                 step = int(total_images / args.batch_size + 1)
diff --git a/benchmark/fluid/models/se_resnext.py b/benchmark/fluid/models/se_resnext.py
index 9f887fb324..7fbb83c2ec 100644
--- a/benchmark/fluid/models/se_resnext.py
+++ b/benchmark/fluid/models/se_resnext.py
@@ -244,11 +244,6 @@ def get_model(args, is_train, main_prog, startup_prog):
 
             optimizer = None
             if is_train:
-                if args.use_lars:
-                    lars_decay = 1.0
-                else:
-                    lars_decay = 0.0
-
                 total_images = 1281167 / trainer_count
 
                 step = int(total_images / args.batch_size + 1)
@@ -262,8 +257,7 @@ def get_model(args, is_train, main_prog, startup_prog):
                     learning_rate=fluid.layers.piecewise_decay(
                         boundaries=bd, values=lr),
                     momentum=0.9,
-                    regularization=fluid.regularizer.L2Decay(1e-4),
-                    LARS_weight_decay=lars_decay)
+                    regularization=fluid.regularizer.L2Decay(1e-4))
                 optimizer.minimize(avg_cost)
 
                 if args.memory_optimize:
diff --git a/doc/README.md b/doc/README.md
new file mode 100644
index 0000000000..77aa2a5322
--- /dev/null
+++ b/doc/README.md
@@ -0,0 +1,7 @@
+# For Readers and Developers
+
+Thanks for reading the PaddlePaddle documentation.
+
+Since **September 17th, 2018**, the documentation source for the **0.15.0 and develop** versions has been moved to the [Fluiddoc Repo](https://github.com/PaddlePaddle/Paddle) and is updated there.
+
+Please refer to the Fluiddoc Repo for the latest documentation.
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index b424ca529e..50f6525e1e 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -73,7 +73,6 @@ paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program',
 paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
 paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
 paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
 paddle.fluid.initializer.NormalInitializer.__init__ ArgSpec(args=['self', 'loc', 'scale', 'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0))
@@ -296,6 +295,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', '
 paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
 paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True))
 paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
+paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,))
 paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
 paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
 paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
@@ -350,25 +350,25 @@ paddle.fluid.nets.simple_img_conv_pool ArgSpec(args=['input', 'num_filters', 'fi
 paddle.fluid.nets.sequence_conv_pool ArgSpec(args=['input', 'num_filters', 'filter_size', 'param_attr', 'act', 'pool_type'], varargs=None, keywords=None, defaults=(None, 'sigmoid', 'max'))
 paddle.fluid.nets.glu ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,))
 paddle.fluid.nets.scaled_dot_product_attention ArgSpec(args=['queries', 'keys', 'values', 'num_heads', 'dropout_rate'], varargs=None, keywords=None, defaults=(1, 0.0))
-paddle.fluid.optimizer.SGDOptimizer.__init__ ArgSpec(args=['self', 'learning_rate'], varargs=None, keywords='kwargs', defaults=None)
+paddle.fluid.optimizer.SGDOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'regularization', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.optimizer.SGDOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.MomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov'], varargs=None, keywords='kwargs', defaults=(False,))
+paddle.fluid.optimizer.MomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov', 'regularization', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
 paddle.fluid.optimizer.MomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon'], varargs=None, keywords='kwargs', defaults=(1e-06,))
+paddle.fluid.optimizer.AdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, None, None))
 paddle.fluid.optimizer.AdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdamOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon'], varargs=None, keywords='kwargs', defaults=(0.001, 0.9, 0.999, 1e-08))
+paddle.fluid.optimizer.AdamOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None))
 paddle.fluid.optimizer.AdamOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdamaxOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon'], varargs=None, keywords='kwargs', defaults=(0.001, 0.9, 0.999, 1e-08))
+paddle.fluid.optimizer.AdamaxOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None))
 paddle.fluid.optimizer.AdamaxOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'decay', 'epsilon'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06))
+paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'decay', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, None, None))
 paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power'], varargs=None, keywords='kwargs', defaults=(0.0, 0.0, -0.5))
+paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.0, 0.0, -0.5, None, None))
 paddle.fluid.optimizer.FtrlOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0, False))
+paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, 0.0, False, None, None))
 paddle.fluid.optimizer.RMSPropOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho'], varargs=None, keywords='kwargs', defaults=(1e-06, 0.95))
+paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, 0.95, None, None))
 paddle.fluid.optimizer.AdadeltaOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_window_rate', 'min_average_window', 'max_average_window'], varargs=None, keywords='kwargs', defaults=(10000, 10000))
+paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_window_rate', 'min_average_window', 'max_average_window', 'regularization', 'name'], varargs=None, keywords=None, defaults=(10000, 10000, None, None))
 paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/framework/details/cow_ptr.h b/paddle/fluid/framework/details/cow_ptr.h
index 4fb015b0ff..21f75957be 100644
--- a/paddle/fluid/framework/details/cow_ptr.h
+++ b/paddle/fluid/framework/details/cow_ptr.h
@@ -20,41 +20,79 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-template <class T>
-class COWPtr {
+// Replace with a thread-safe flags class if needed.
+class ThreadUnsafeOwnershipFlags {
  public:
-  typedef std::shared_ptr<T> RefPtr;
+  explicit ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {}
 
- private:
-  RefPtr m_sp;
+  ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete;
+  ThreadUnsafeOwnershipFlags& operator=(
+      const ThreadUnsafeOwnershipFlags& other) = delete;
+  ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default;
 
-  void detach() {
-    T* tmp = m_sp.get();
-    if (!(tmp == nullptr || m_sp.unique())) {
-      m_sp = RefPtr(new T(*tmp));
+  void SetOwnership(bool flag) { flag_ = flag; }
+
+  // Invoke the callback once if the payload is not yet owned.
+  template <typename Callback>
+  void AcquireOwnershipOnce(Callback acquire) {
+    if (!flag_) {
+      acquire();
+      flag_ = true;
     }
   }
 
- public:
-  COWPtr() : m_sp(nullptr) {}
-  explicit COWPtr(T* t) : m_sp(t) {}
-  explicit COWPtr(const RefPtr& refptr) : m_sp(refptr) {}
+ private:
+  bool flag_;
+};
 
-  const T& Data() const { return operator*(); }
+// Copy-On-Write pointer.
+// It holds a T* payload and copies it at most once, when `MutableData` is
+// invoked on a non-owning instance.
+//
+// The template parameter OwnershipFlags should provide:
+//   * a constructor that takes a bool (true if the payload is owned).
+//   * SetOwnership(bool flag).
+//   * AcquireOwnershipOnce(Callback), which invokes the callback if the
+//     payload is not owned.
+//
+// https://en.wikipedia.org/wiki/Copy-on-write
+template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags>
+class COWPtr {
+ public:
+  // Ctor from raw pointer.
+  explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {}
 
-  T* MutableData() { return operator->(); }
+  // Move methods. Steal ownership from the source.
+  COWPtr(COWPtr&& other)
+      : payload_(other.payload_), ownership_{std::move(other.ownership_)} {}
+  COWPtr& operator=(COWPtr&& origin) = default;
 
-  const T& operator*() const { return *m_sp; }
-  T& operator*() {
-    detach();
-    return *m_sp;
+  // Copy methods. The copy does not own the payload.
+  COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {}
+  COWPtr& operator=(const COWPtr& other) {
+    payload_ = other.payload_;
+    ownership_.SetOwnership(false);
+    return *this;
   }
-  const T* operator->() const { return m_sp.operator->(); }
-  T* operator->() {
-    detach();
-    return m_sp.operator->();
+
+  // Access read-only data.
+  const T& Data() const { return *payload_; }
+
+  // Access mutable data. If the data is not owned, it is copied
+  // before the pointer is returned.
+  T* MutableData() {
+    ownership_.AcquireOwnershipOnce(
+        [this] { payload_.reset(new T(*payload_)); });
+    return payload_.get();
   }
+
+ private:
+  // Actual data pointer.
+  std::shared_ptr<T> payload_;
+
+  // Ownership flag.
+  OwnershipFlags ownership_;
 };
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/cow_ptr_test.cc b/paddle/fluid/framework/details/cow_ptr_test.cc
index 5b055d7cb4..d2142af277 100644
--- a/paddle/fluid/framework/details/cow_ptr_test.cc
+++ b/paddle/fluid/framework/details/cow_ptr_test.cc
@@ -30,14 +30,6 @@ TEST(COWPtr, all) {
   ASSERT_EQ(ptr2.Data(), 10);
 }
 
-TEST(COWPtr, change_old) {
-  COWPtr<int> ptr(new int{0});
-  COWPtr<int> ptr2 = ptr;
-  *ptr.MutableData() = 10;
-  ASSERT_EQ(ptr2.Data(), 0);
-  ASSERT_EQ(ptr.Data(), 10);
-}
-
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 8f319116ab..134fcee826 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -210,43 +210,6 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
   return recv_vars;
 }
 
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(
-    ir::Node *node, const std::vector<std::string> &send_vars,
-    const std::vector<std::string> &recv_vars) const {
-  if (send_vars.size() == 0 || recv_vars.size() == 0) {
-    return false;
-  }
-
-  /**
-   * Check any of opvars contains `.block` and in sendvars
-   */
-  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &rpc_vars) -> bool {
-    for (auto &var : opvars) {
-      // a variable name with the suffix `.block` means it's a splited
-      // variable by (DistributeTranspiler)
-      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
-      if (var.find(".block") != std::string::npos &&
-          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  std::vector<std::string> input_var_names;
-  std::vector<std::string> output_var_names;
-  for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Name());
-  }
-  for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Name());
-  }
-
-  return checker(output_var_names, send_vars) ||
-         checker(input_var_names, recv_vars);
-}
-
 size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
     const std::vector<std::string> &var_names) const {
   int64_t numel_sum = 0;
@@ -370,7 +333,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         }
       }
       is_dist_train = true;
-    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
+    } else if (boost::get<int>(node->Op()->GetAttr(
+                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kDist)) {
       int op_dev_id = CreateDistTrainOp(&result, node);
       if (node->Op()->Type() == "concat") {
         auto origin_param_name = node->Op()->OutputArgumentNames()[0];
@@ -736,6 +701,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
           .emplace(varname, op_dev_id);
     }
   } else {
+    LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
     PADDLE_THROW(
         "the distribute training related op should be in [split_byref, "
         "concat].");
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 47aaa80f4d..cdf9f13cde 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -51,12 +51,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
   int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
 
-  /**
-   * Is this operator as the end-point operator before/after send operator.
-   */
-  bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
-                     const std::vector<std::string> &recv_vars) const;
-
   std::vector<std::string> FindDistTrainSendVars(
       const std::vector<ir::Node *> &nodes) const;
 
diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h
index ba2c41eb89..7836ecb127 100644
--- a/paddle/fluid/framework/mixed_vector.h
+++ b/paddle/fluid/framework/mixed_vector.h
@@ -17,12 +17,10 @@
 #include <algorithm>
 #include <initializer_list>
 #include <memory>
-#include <utility>
 #include <vector>
-#include "paddle/fluid/framework/details/cow_ptr.h"
+
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/memory/memcpy.h"
 
 #include "glog/logging.h"
 
@@ -30,401 +28,206 @@ namespace paddle {
 namespace framework {
 
 #if defined(PADDLE_WITH_CUDA)
-namespace details {
-struct CUDABuffer {
-  void *data_{nullptr};
-  size_t size_{0};
-  platform::CUDAPlace place_;
-
-  CUDABuffer() {}
-  CUDABuffer(platform::Place place, size_t size)
-      : size_(size), place_(boost::get<platform::CUDAPlace>(place)) {
-    data_ = memory::Alloc(place_, size);
-  }
-
-  ~CUDABuffer() { ClearMemory(); }
-
-  CUDABuffer(const CUDABuffer &o) = delete;
-  CUDABuffer &operator=(const CUDABuffer &o) = delete;
-
-  void Resize(platform::Place place, size_t size) {
-    ClearMemory();
-    place_ = boost::get<platform::CUDAPlace>(place);
-    data_ = memory::Alloc(place_, size);
-    size_ = size;
-  }
-
-  void Swap(CUDABuffer &o) {
-    std::swap(data_, o.data_);
-    std::swap(place_, o.place_);
-    std::swap(size_, o.size_);
-  }
-
- private:
-  void ClearMemory() const {
-    if (data_) {
-      memory::Free(place_, data_);
-    }
-  }
-};
-}  // namespace details
-
 // Vector<T> implements the std::vector interface, and can get Data or
 // MutableData from any place. The data will be synced implicitly inside.
 template <typename T>
 class Vector {
  public:
   using value_type = T;
-  using iterator = typename std::vector<T>::iterator;
-  using const_iterator = typename std::vector<T>::const_iterator;
-
- private:
-  // The actual class to implement vector logic
-  class VectorData {
-   public:
-    VectorData() : flag_(kDataInCPU) {}
-    VectorData(size_t count, const T &value)
-        : cpu_(count, value), flag_(kDataInCPU) {}
-    VectorData(std::initializer_list<T> init) : cpu_(init), flag_(kDataInCPU) {}
-    template <typename U>
-    explicit VectorData(const std::vector<U> &dat)
-        : cpu_(dat), flag_(kDataInCPU) {}
-
-    VectorData(const VectorData &o) {
-      o.ImmutableCPU();
-      cpu_ = o.cpu_;
-      flag_ = kDataInCPU;
-    }
-
-    VectorData &operator=(const VectorData &o) {
-      o.ImmutableCPU();
-      cpu_ = o.cpu_;
-      flag_ = kDataInCPU;
-      details::CUDABuffer null;
-      gpu_.Swap(null);
-      return *this;
-    }
-
-    T &operator[](size_t i) {
-      MutableCPU();
-      return cpu_[i];
-    }
-
-    const T &operator[](size_t i) const {
-      ImmutableCPU();
-      return cpu_[i];
-    }
-
-    size_t size() const { return cpu_.size(); }
-
-    iterator begin() {
-      MutableCPU();
-      return cpu_.begin();
-    }
-
-    iterator end() {
-      MutableCPU();
-      return cpu_.end();
-    }
-
-    T &front() {
-      MutableCPU();
-      return cpu_.front();
-    }
-
-    T &back() {
-      MutableCPU();
-      return cpu_.back();
-    }
-
-    const_iterator begin() const {
-      ImmutableCPU();
-      return cpu_.begin();
-    }
-
-    const_iterator end() const {
-      ImmutableCPU();
-      return cpu_.end();
-    }
-
-    const T &back() const {
-      ImmutableCPU();
-      return cpu_.back();
-    }
-
-    T *data() { return &(*this)[0]; }
-
-    const T *data() const { return &(*this)[0]; }
-
-    const T &front() const {
-      ImmutableCPU();
-      return cpu_.front();
-    }
-
-    // assign this from iterator.
-    // NOTE: the iterator must support `end-begin`
-    template <typename Iter>
-    void assign(Iter begin, Iter end) {
-      MutableCPU();
-      cpu_.assign(begin, end);
-    }
-
-    // push_back. If the previous capacity is not enough, the memory will
-    // double.
-    void push_back(T elem) {
-      MutableCPU();
-      cpu_.push_back(elem);
-    }
-
-    // extend a vector by iterator.
-    // NOTE: the iterator must support end-begin
-    template <typename It>
-    void Extend(It begin, It end) {
-      MutableCPU();
-      auto out_it = std::back_inserter<std::vector<T>>(this->cpu_);
-      std::copy(begin, end, out_it);
-    }
-
-    // resize the vector
-    void resize(size_t size) {
-      MutableCPU();
-      cpu_.resize(size);
-    }
-
-    // get cuda ptr. immutable
-    const T *CUDAData(platform::Place place) const {
-      PADDLE_ENFORCE(platform::is_gpu_place(place),
-                     "CUDA Data must on CUDA place");
-      ImmutableCUDA(place);
-      return reinterpret_cast<T *>(gpu_.data_);
-    }
-
-    // get cuda ptr. mutable
-    T *CUDAMutableData(platform::Place place) {
-      const T *ptr = CUDAData(place);
-      flag_ = kDirty | kDataInCUDA;
-      return const_cast<T *>(ptr);
-    }
-
-    // clear
-    void clear() {
-      cpu_.clear();
-      flag_ = kDirty | kDataInCPU;
-    }
-
-    size_t capacity() const { return cpu_.capacity(); }
-
-    // reserve data
-    void reserve(size_t size) { cpu_.reserve(size); }
-
-    // implicit cast operator. Vector can be cast to std::vector implicitly.
-    operator std::vector<T>() const {
-      ImmutableCPU();
-      return cpu_;
-    }
-
-    bool operator==(const VectorData &other) const {
-      ImmutableCPU();
-      other.ImmutableCPU();
-      return cpu_ == other.cpu_;
-    }
-
-   private:
-    enum DataFlag {
-      kDataInCPU = 0x01,
-      kDataInCUDA = 0x02,
-      // kDirty means the data has been changed in one device.
-      kDirty = 0x10
-    };
-
-    void CopyToCPU() const {
-      // COPY GPU Data To CPU
-      void *src = gpu_.data_;
-      void *dst = cpu_.data();
-      memory::Copy(platform::CPUPlace(), dst, gpu_.place_, src, gpu_.size_,
-                   nullptr);
-    }
-
-    void MutableCPU() {
-      if (IsInCUDA() && IsDirty()) {
-        CopyToCPU();
-      }
-      flag_ = kDirty | kDataInCPU;
-    }
-
-    void ImmutableCUDA(platform::Place place) const {
-      if (IsDirty()) {
-        if (IsInCPU()) {
-          CopyCPUDataToCUDA(place);
-          UnsetFlag(kDirty);
-          SetFlag(kDataInCUDA);
-        } else if (IsInCUDA() &&
-                   !(boost::get<platform::CUDAPlace>(place) == gpu_.place_)) {
-          CopyCUDADataToAnotherPlace(place);
-          // Still dirty
-        } else {
-          // Dirty && DataInCUDA && Device is same
-          // Do nothing
-        }
-      } else {
-        if (!IsInCUDA()) {
-          // Even data is not dirty. However, data is not in CUDA. Copy data.
-          CopyCPUDataToCUDA(place);
-          SetFlag(kDataInCUDA);
-        } else if (!(boost::get<platform::CUDAPlace>(place) == gpu_.place_)) {
-          CopyCUDADataToAnotherPlace(place);
-        } else {
-          // Not Dirty && DataInCUDA && Device is same
-          // Do nothing.
-        }
-      }
-    }
-    void CopyCUDADataToAnotherPlace(const platform::Place &place) const {
-      details::CUDABuffer tmp(place, gpu_.size_);
-      const void *src = gpu_.data_;
-      void *dst = tmp.data_;
-
-      memory::Copy(tmp.place_, dst, gpu_.place_, src, gpu_.size_, nullptr);
-      gpu_.Swap(tmp);
-    }
-    void CopyCPUDataToCUDA(const platform::Place &place) const {
-      void *src = cpu_.data();
-      gpu_.Resize(place, cpu_.size() * sizeof(T));
-      void *dst = gpu_.data_;
-      auto stream = static_cast<platform::CUDADeviceContext *>(
-                        platform::DeviceContextPool::Instance().Get(place))
-                        ->stream();
-      memory::Copy(gpu_.place_, dst, platform::CPUPlace(), src, gpu_.size_,
-                   stream);
-    }
-
-    void ImmutableCPU() const {
-      if (IsDirty() && !IsInCPU()) {  // If data has been changed in CUDA, or
-                                      // CPU has no data.
-        CopyToCPU();
-        UnsetFlag(kDirty);
-      }
-      SetFlag(kDataInCPU);
-    }
-
-    void UnsetFlag(int flag) const { flag_ &= ~flag; }
-    void SetFlag(int flag) const { flag_ |= flag; }
-
-    bool IsDirty() const { return flag_ & kDirty; }
-
-    bool IsInCUDA() const { return flag_ & kDataInCUDA; }
 
-    bool IsInCPU() const { return flag_ & kDataInCPU; }
-
-    mutable std::vector<T> cpu_;
-    mutable details::CUDABuffer gpu_;
-    mutable int flag_;
-  };
-
- public:
   // Default ctor. Create empty Vector
-  Vector() : m_(new VectorData()) {}
+  Vector() { InitEmpty(); }
 
   // Fill vector with value. The vector size is `count`.
-  explicit Vector(size_t count, const T &value = T())
-      : m_(new VectorData(count, value)) {}
+  explicit Vector(size_t count, const T &value = T()) {
+    InitEmpty();
+    if (count != 0) {
+      resize(count);
+      T *ptr = begin();
+      for (size_t i = 0; i < count; ++i) {
+        ptr[i] = value;
+      }
+    }
+  }
 
   // Ctor with init_list
-  Vector(std::initializer_list<T> init) : m_(new VectorData(init)) {}
+  Vector(std::initializer_list<T> init) {
+    if (init.size() == 0) {
+      InitEmpty();
+    } else {
+      InitByIter(init.size(), init.begin(), init.end());
+    }
+  }
 
   // implicit cast from std::vector.
   template <typename U>
-  Vector(const std::vector<U> &dat) : m_(new VectorData(dat)) {  // NOLINT
+  Vector(const std::vector<U> &dat) {  // NOLINT
+    if (dat.size() == 0) {
+      InitEmpty();
+    } else {
+      InitByIter(dat.size(), dat.begin(), dat.end());
+    }
   }
 
   // Copy ctor
-  Vector(const Vector<T> &other) { m_ = other.m_; }
+  Vector(const Vector<T> &other) { this->operator=(other); }
 
   // Copy operator
   Vector<T> &operator=(const Vector<T> &other) {
-    m_ = other.m_;
+    if (other.size() != 0) {
+      this->InitByIter(other.size(), other.begin(), other.end());
+    } else {
+      InitEmpty();
+    }
     return *this;
   }
 
   // Move ctor
-  Vector(Vector<T> &&other) { m_ = std::move(other.m_); }
+  Vector(Vector<T> &&other) {
+    this->size_ = other.size_;
+    this->flag_ = other.flag_;
+    if (other.cuda_vec_.memory_size()) {
+      this->cuda_vec_.ShareDataWith(other.cuda_vec_);
+    }
+    if (other.cpu_vec_.memory_size()) {
+      this->cpu_vec_.ShareDataWith(other.cpu_vec_);
+    }
+  }
 
   // CPU data access method. Mutable.
-  T &operator[](size_t i) { return (*m_)[i]; }
+  T &operator[](size_t i) {
+    MutableCPU();
+    return const_cast<T *>(cpu_vec_.data<T>())[i];
+  }
 
   // CPU data access method. Immutable.
-  const T &operator[](size_t i) const { return (*m_)[i]; }
+  const T &operator[](size_t i) const {
+    ImmutableCPU();
+    return cpu_vec_.data<T>()[i];
+  }
 
   // std::vector iterator methods. Based on CPU data access method
-  size_t size() const { return m_->size(); }
+  size_t size() const { return size_; }
 
-  iterator begin() { return m_->begin(); }
+  T *begin() { return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); }
 
-  iterator end() { return m_->end(); }
+  T *end() {
+    return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
+  }
 
-  T &front() { return m_->front(); }
+  T &front() { return *begin(); }
 
-  T &back() { return m_->back(); }
+  T &back() {
+    auto it = end();
+    --it;
+    return *it;
+  }
 
-  const_iterator begin() const { return m_->begin(); }
+  const T *begin() const {
+    return capacity() == 0 ? &EmptyDummy() : &this->operator[](0);
+  }
 
-  const_iterator end() const { return m_->end(); }
+  const T *end() const {
+    return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
+  }
 
-  const_iterator cbegin() const { return begin(); }
+  const T *cbegin() const { return begin(); }
 
-  const_iterator cend() const { return end(); }
+  const T *cend() const { return end(); }
 
-  const T &back() const { return m_->back(); }
+  const T &back() const {
+    auto it = end();
+    --it;
+    return *it;
+  }
 
-  T *data() { return m_->data(); }
+  T *data() { return begin(); }
 
-  const T *data() const { return m_->data(); }
+  const T *data() const { return begin(); }
 
-  const T &front() const { return m_->front(); }
+  const T &front() const { return *begin(); }
   // end of std::vector iterator methods
 
   // assign this from iterator.
   // NOTE: the iterator must support `end-begin`
   template <typename Iter>
   void assign(Iter begin, Iter end) {
-    m_->assign(begin, end);
+    InitByIter(end - begin, begin, end);
   }
 
   // push_back. If the previous capacity is not enough, the memory will
   // double.
-  void push_back(T elem) { m_->push_back(elem); }
+  void push_back(T elem) {
+    if (size_ + 1 > capacity()) {
+      reserve((size_ + 1) << 1);
+    }
+    *end() = elem;
+    ++size_;
+  }
 
   // extend a vector by iterator.
   // NOTE: the iterator must support end-begin
   template <typename It>
   void Extend(It begin, It end) {
-    m_->Extend(begin, end);
+    size_t pre_size = size_;
+    resize(pre_size + (end - begin));
+    T *ptr = this->begin() + pre_size;
+    for (; begin < end; ++begin, ++ptr) {
+      *ptr = *begin;
+    }
   }
 
   // resize the vector
   void resize(size_t size) {
-    if (m_.Data().size() != size) {
-      m_->resize(size);
+    if (size + 1 <= capacity()) {
+      size_ = size;
+    } else {
+      MutableCPU();
+      Tensor cpu_tensor;
+      platform::Place cpu = platform::CPUPlace();
+      T *ptr = cpu_tensor.mutable_data<T>(
+          framework::make_ddim({static_cast<int64_t>(size)}), cpu);
+      const T *old_ptr =
+          cpu_vec_.memory_size() == 0 ? nullptr : cpu_vec_.data<T>();
+      if (old_ptr != nullptr) {
+        std::copy(old_ptr, old_ptr + size_, ptr);
+      }
+      size_ = size;
+      cpu_vec_.ShareDataWith(cpu_tensor);
     }
   }
 
   // get cuda ptr. immutable
   const T *CUDAData(platform::Place place) const {
-    return m_.Data().CUDAData(place);
+    PADDLE_ENFORCE(platform::is_gpu_place(place),
+                   "CUDA Data must on CUDA place");
+    ImmutableCUDA(place);
+    return cuda_vec_.data<T>();
   }
 
   // get cuda ptr. mutable
   T *CUDAMutableData(platform::Place place) {
-    return m_->CUDAMutableData(place);
+    const T *ptr = CUDAData(place);
+    flag_ = kDirty | kDataInCUDA;
+    return const_cast<T *>(ptr);
   }
 
   // clear
-  void clear() { m_->clear(); }
+  void clear() {
+    size_ = 0;
+    flag_ = kDirty | kDataInCPU;
+  }
 
-  size_t capacity() const { return m_->capacity(); }
+  size_t capacity() const {
+    return cpu_vec_.memory_size() / SizeOfType(typeid(T));
+  }
 
   // reserve data
-  void reserve(size_t size) { m_->reserve(size); }
+  void reserve(size_t size) {
+    size_t pre_size = size_;
+    resize(size);
+    resize(pre_size);
+  }
 
   // the unify method to access CPU or CUDA data. immutable.
   const T *Data(platform::Place place) const {
@@ -445,7 +248,12 @@ class Vector {
   }
 
   // implicit cast operator. Vector can be cast to std::vector implicitly.
-  operator std::vector<T>() const { return *m_; }
+  operator std::vector<T>() const {
+    std::vector<T> result;
+    result.resize(size());
+    std::copy(begin(), end(), result.begin());
+    return result;
+  }
 
   bool operator==(const Vector<T> &other) const {
     if (size() != other.size()) return false;
@@ -459,11 +267,118 @@ class Vector {
     return true;
   }
 
-  const void *Handle() const { return &m_.Data(); }
-
  private:
-  // Vector is an COW object.
-  details::COWPtr<VectorData> m_;
+  void InitEmpty() {
+    size_ = 0;
+    flag_ = kDataInCPU;
+  }
+
+  template <typename Iter>
+  void InitByIter(size_t size, Iter begin, Iter end) {
+    platform::Place cpu = platform::CPUPlace();
+    T *ptr = this->cpu_vec_.template mutable_data<T>(
+        framework::make_ddim({static_cast<int64_t>(size)}), cpu);
+    for (size_t i = 0; i < size; ++i) {
+      *ptr++ = *begin++;
+    }
+    flag_ = kDataInCPU | kDirty;
+    size_ = size;
+  }
+
+  enum DataFlag {
+    kDataInCPU = 0x01,
+    kDataInCUDA = 0x02,
+    // kDirty means the data has been changed in one device.
+    kDirty = 0x10
+  };
+
+  void CopyToCPU() const {
+    // COPY GPU Data To CPU
+    TensorCopy(cuda_vec_, platform::CPUPlace(), &cpu_vec_);
+    WaitPlace(cuda_vec_.place());
+  }
+
+  void MutableCPU() {
+    if (IsInCUDA() && IsDirty()) {
+      CopyToCPU();
+    }
+    flag_ = kDirty | kDataInCPU;
+  }
+
+  void ImmutableCUDA(platform::Place place) const {
+    if (IsDirty()) {
+      if (IsInCPU()) {
+        TensorCopy(cpu_vec_, boost::get<platform::CUDAPlace>(place),
+                   &cuda_vec_);
+        WaitPlace(place);
+        UnsetFlag(kDirty);
+        SetFlag(kDataInCUDA);
+      } else if (IsInCUDA() && !(place == cuda_vec_.place())) {
+        framework::Tensor tmp;
+        TensorCopy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
+        WaitPlace(cuda_vec_.place());
+        cuda_vec_.ShareDataWith(tmp);
+        // Still dirty
+      } else {
+        // Dirty && DataInCUDA && Device is same
+        // Do nothing
+      }
+    } else {
+      if (!IsInCUDA()) {
+        // Even data is not dirty. However, data is not in CUDA. Copy data.
+        TensorCopy(cpu_vec_, boost::get<platform::CUDAPlace>(place),
+                   &cuda_vec_);
+        WaitPlace(place);
+        SetFlag(kDataInCUDA);
+      } else if (!(place == cuda_vec_.place())) {
+        framework::Tensor tmp;
+        WaitPlace(cuda_vec_.place());
+        TensorCopy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
+        WaitPlace(cuda_vec_.place());
+        WaitPlace(place);
+        cuda_vec_.ShareDataWith(tmp);
+      } else {
+        // Not Dirty && DataInCUDA && Device is same
+        // Do nothing.
+      }
+    }
+  }
+
+  void ImmutableCPU() const {
+    if (IsDirty() &&
+        !IsInCPU()) {  // If data has been changed in CUDA, or CPU has no data.
+      CopyToCPU();
+      UnsetFlag(kDirty);
+    }
+    SetFlag(kDataInCPU);
+  }
+
+  void UnsetFlag(int flag) const { flag_ &= ~flag; }
+  void SetFlag(int flag) const { flag_ |= flag; }
+
+  bool IsDirty() const { return flag_ & kDirty; }
+
+  bool IsInCUDA() const { return flag_ & kDataInCUDA; }
+
+  bool IsInCPU() const { return flag_ & kDataInCPU; }
+
+  static void WaitPlace(const platform::Place place) {
+    if (platform::is_gpu_place(place)) {
+      platform::DeviceContextPool::Instance()
+          .Get(boost::get<platform::CUDAPlace>(place))
+          ->Wait();
+    }
+  }
+
+  static T &EmptyDummy() {
+    static T dummy = T();
+    return dummy;
+  }
+
+  mutable int flag_;
+  mutable Tensor cpu_vec_;
+  mutable Tensor cuda_vec_;
+  size_t size_;
 };
 
 #else  // PADDLE_WITH_CUDA
diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc
index 4fa047bf3e..df2a7a27ca 100644
--- a/paddle/fluid/framework/op_proto_maker.cc
+++ b/paddle/fluid/framework/op_proto_maker.cc
@@ -120,6 +120,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
           {static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kBackward),
            static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
+           static_cast<int>(OpRole::kDist), static_cast<int>(OpRole::kLRSched),
            static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kLoss) |
                static_cast<int>(OpRole::kBackward),
diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h
index 18827385ad..4ed3cc45d6 100644
--- a/paddle/fluid/framework/op_proto_maker.h
+++ b/paddle/fluid/framework/op_proto_maker.h
@@ -26,7 +26,13 @@ enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
   kOptimize = 0x0002,
+  // RPC role is for send/recv related ops.
   kRPC = 0x0003,
+  // Dist role is for the split_byref/split_selected_rows/concat ops
+  // used in distributed training.
+  kDist = 0x0004,
+  // LRSched role tags all learning rate scheduler operators.
+  kLRSched = 0x0005,
 
   kLoss = 0x0100,
   // The default value of op's role. This should be only used for unittests and
diff --git a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
index bf893e3256..36bbec4731 100644
--- a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
@@ -103,108 +103,74 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   input_slots->assign({input_tensor});
 }
 
-const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25,
-                                25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43,
-                                44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39,
-                                14, 15, 44, 22, 23, 23, 23, 23, 23, 23, 23};
-
-void TestLACPrediction(const std::string &model_path,
-                       const std::string &data_file, const int batch_size,
-                       const int repeat, bool use_analysis = false) {
-  AnalysisConfig cfg;
-  cfg.model_dir = model_path;
-  cfg.use_gpu = false;
-  cfg.device = 0;
-  cfg.specify_input_name = true;
-  cfg.enable_ir_optim = true;
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->model_dir = FLAGS_infer_model;
+  cfg->use_gpu = false;
+  cfg->device = 0;
+  cfg->specify_input_name = true;
+  cfg->enable_ir_optim = true;
+}
 
-  std::vector<PaddleTensor> input_slots, outputs_slots;
-  DataRecord data(data_file, batch_size);
-  GetOneBatch(&input_slots, &data, batch_size);
-  std::unique_ptr<PaddlePredictor> predictor;
-  if (use_analysis) {
-    predictor =
-        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(cfg);
-  } else {
-    predictor =
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
-  }
-  for (int i = 0; i < FLAGS_burning; i++) {
-    predictor->Run(input_slots, &outputs_slots);
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
+  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+  std::vector<PaddleTensor> input_slots;
+  int epoch = FLAGS_test_all_data ? data.batched_datas.size() : 1;
+  LOG(INFO) << "number of samples: " << epoch;
+  for (int bid = 0; bid < epoch; ++bid) {
+    GetOneBatch(&input_slots, &data, FLAGS_batch_size);
+    (*inputs).emplace_back(input_slots);
   }
-  Timer timer;
-  if (FLAGS_test_all_data) {
-    LOG(INFO) << "test all data";
-    std::vector<std::vector<PaddleTensor>> input_slots_all;
-    for (size_t bid = 0; bid < data.batched_datas.size(); ++bid) {
-      GetOneBatch(&input_slots, &data, batch_size);
-      input_slots_all.emplace_back(input_slots);
-    }
-    LOG(INFO) << "total number of samples: " << data.datasets.size();
-    TestPrediction(cfg, input_slots_all, &outputs_slots, FLAGS_num_threads);
-    return;
-  }
-  timer.tic();
-  for (int i = 0; i < repeat; i++) {
-    predictor->Run(input_slots, &outputs_slots);
-  }
-  PrintTime(batch_size, repeat, 1, 0, timer.toc() / repeat);
+}
 
-  // check result
-  EXPECT_EQ(outputs_slots.size(), 1UL);
-  auto &out = outputs_slots[0];
-  size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
-                                [](int a, int b) { return a * b; });
-  size_t batch1_size = sizeof(lac_ref_data) / sizeof(int64_t);
-  PADDLE_ENFORCE_GT(size, 0);
-  EXPECT_GE(size, batch1_size);
-  int64_t *pdata = static_cast<int64_t *>(out.data.data());
-  for (size_t i = 0; i < batch1_size; ++i) {
-    EXPECT_EQ(pdata[i], lac_ref_data[i]);
-  }
+// Easy to profile independently.
+TEST(Analyzer_LAC, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
 
-  if (use_analysis) {
-    // run once for comparion as reference
-    auto ref_predictor =
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
-    std::vector<PaddleTensor> ref_outputs_slots;
-    ref_predictor->Run(input_slots, &ref_outputs_slots);
-    CompareResult(ref_outputs_slots, outputs_slots);
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
 
-    AnalysisPredictor *analysis_predictor =
-        dynamic_cast<AnalysisPredictor *>(predictor.get());
-    auto &fuse_statis = analysis_predictor->analysis_argument()
-                            .Get<std::unordered_map<std::string, int>>(
-                                framework::ir::kFuseStatisAttr);
-    for (auto &item : fuse_statis) {
-      LOG(INFO) << "fused " << item.first << " " << item.second;
-    }
-    int num_ops = 0;
-    for (auto &node :
-         analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
-      if (node->IsFunction()) {
-        ++num_ops;
-      }
+  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
+    // the first inference result
+    const int64_t lac_ref_data[] = {
+        24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, 25, 25, 25, 25,
+        44, 24, 25, 25, 25, 36, 42, 43, 44, 14, 15, 44, 14, 15, 44, 14,
+        15, 44, 38, 39, 14, 15, 44, 22, 23, 23, 23, 23, 23, 23, 23};
+    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
+    size_t size = GetSize(outputs[0]);
+    size_t batch1_size = sizeof(lac_ref_data) / sizeof(int64_t);
+    PADDLE_ENFORCE_GE(size, batch1_size);
+    int64_t *pdata = static_cast<int64_t *>(outputs[0].data.data());
+    for (size_t i = 0; i < batch1_size; ++i) {
+      EXPECT_EQ(pdata[i], lac_ref_data[i]);
     }
-    LOG(INFO) << "has num ops: " << num_ops;
-    ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-    ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
-    EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
-    EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 4);
-    EXPECT_EQ(num_ops, 11);
   }
 }
 
-TEST(Analyzer_LAC, native) {
-  LOG(INFO) << "LAC with native";
-  TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size,
-                    FLAGS_repeat);
+// Check the fuse status
+TEST(Analyzer_LAC, fuse_statis) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  int num_ops;
+  auto fuse_statis = GetFuseStatis(cfg, &num_ops);
+  ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+  ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
+  EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
+  EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 4);
+  EXPECT_EQ(num_ops, 11);
 }
 
-TEST(Analyzer_LAC, analysis) {
-  LOG(INFO) << "LAC with analysis";
-  TestLACPrediction(FLAGS_infer_model, FLAGS_infer_data, FLAGS_batch_size,
-                    FLAGS_repeat, true);
+// Compare the results of NativeConfig and AnalysisConfig.
+TEST(Analyzer_LAC, compare) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
index f8c651e32f..8cf230a51d 100644
--- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
@@ -95,97 +95,73 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   }
 }
 
-// the first inference result
-const int chinese_ner_result_data[] = {30, 45, 41, 48, 17, 26,
-                                       48, 39, 38, 16, 25};
-
-void TestChineseNERPrediction(bool use_analysis) {
-  AnalysisConfig cfg;
-  cfg.prog_file = FLAGS_infer_model + "/__model__";
-  cfg.param_file = FLAGS_infer_model + "/param";
-  cfg.use_gpu = false;
-  cfg.device = 0;
-  cfg.specify_input_name = true;
-  cfg.enable_ir_optim = true;
-
-  std::vector<PaddleTensor> input_slots, outputs;
-  std::unique_ptr<PaddlePredictor> predictor;
-  Timer timer;
-  if (use_analysis) {
-    predictor =
-        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(cfg);
-  } else {
-    predictor =
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
-  }
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->prog_file = FLAGS_infer_model + "/__model__";
+  cfg->param_file = FLAGS_infer_model + "/param";
+  cfg->use_gpu = false;
+  cfg->device = 0;
+  cfg->specify_input_name = true;
+  cfg->enable_ir_optim = true;
+}
 
-  if (FLAGS_test_all_data) {
-    LOG(INFO) << "test all data";
-    DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
-    std::vector<std::vector<PaddleTensor>> input_slots_all;
-    for (size_t bid = 0; bid < data.num_samples / FLAGS_batch_size; ++bid) {
-      PrepareInputs(&input_slots, &data, FLAGS_batch_size);
-      input_slots_all.emplace_back(input_slots);
-    }
-    LOG(INFO) << "total number of samples: " << data.num_samples;
-    TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
-    return;
-  }
-  // Prepare inputs.
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
   DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
-  PrepareInputs(&input_slots, &data, FLAGS_batch_size);
-
-  timer.tic();
-  for (int i = 0; i < FLAGS_repeat; i++) {
-    predictor->Run(input_slots, &outputs);
+  std::vector<PaddleTensor> input_slots;
+  int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
+  LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
+  for (int bid = 0; bid < epoch; ++bid) {
+    PrepareInputs(&input_slots, &data, FLAGS_batch_size);
+    (*inputs).emplace_back(input_slots);
   }
-  PrintTime(FLAGS_batch_size, FLAGS_repeat, 1, 0, timer.toc() / FLAGS_repeat);
+}
 
-  PADDLE_ENFORCE(outputs.size(), 1UL);
-  auto &out = outputs[0];
-  size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
-                                [](int a, int b) { return a * b; });
-  PADDLE_ENFORCE_GT(size, 0);
-  int64_t *result = static_cast<int64_t *>(out.data.data());
-  for (size_t i = 0; i < std::min(11UL, size); i++) {
-    PADDLE_ENFORCE(result[i], chinese_ner_result_data[i]);
-  }
+// Easy to profile independently.
+TEST(Analyzer_Chinese_ner, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
 
-  if (use_analysis) {
-    // run once for comparion as reference
-    auto ref_predictor =
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
-    std::vector<PaddleTensor> ref_outputs_slots;
-    ref_predictor->Run(input_slots, &ref_outputs_slots);
-    CompareResult(ref_outputs_slots, outputs);
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
 
-    AnalysisPredictor *analysis_predictor =
-        dynamic_cast<AnalysisPredictor *>(predictor.get());
-    auto &fuse_statis = analysis_predictor->analysis_argument()
-                            .Get<std::unordered_map<std::string, int>>(
-                                framework::ir::kFuseStatisAttr);
-    for (auto &item : fuse_statis) {
-      LOG(INFO) << "fused " << item.first << " " << item.second;
-    }
-    int num_ops = 0;
-    for (auto &node :
-         analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
-      if (node->IsFunction()) {
-        ++num_ops;
-      }
+  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
+    // the first inference result
+    const int chinese_ner_result_data[] = {30, 45, 41, 48, 17, 26,
+                                           48, 39, 38, 16, 25};
+    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
+    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(size, 0);
+    int64_t *result = static_cast<int64_t *>(outputs[0].data.data());
+    for (size_t i = 0; i < std::min(11UL, size); i++) {
+      EXPECT_EQ(result[i], chinese_ner_result_data[i]);
     }
-    LOG(INFO) << "has num ops: " << num_ops;
-    ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-    ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
-    EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
-    EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 2);
-    EXPECT_EQ(num_ops, 14);
   }
 }
 
-TEST(Analyzer_Chinese_ner, native) { TestChineseNERPrediction(false); }
+// Check the fuse status
+TEST(Analyzer_Chinese_ner, fuse_statis) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
 
-TEST(Analyzer_Chinese_ner, analysis) { TestChineseNERPrediction(true); }
+  int num_ops;
+  auto fuse_statis = GetFuseStatis(cfg, &num_ops);
+  ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+  ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
+  EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
+  EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 2);
+  EXPECT_EQ(num_ops, 14);
+}
+
+// Compare the results of NativeConfig and AnalysisConfig.
+TEST(Analyzer_Chinese_ner, compare) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
+}
 
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
index df96be544e..14bdf76efc 100644
--- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
@@ -25,6 +25,7 @@ struct DataRecord {
   std::vector<size_t> lod1, lod2, lod3;
   std::vector<std::vector<float>> rnn_link_data, rnn_week_datas,
       rnn_minute_datas;
+  size_t num_samples;  // total number of samples
   size_t batch_iter{0};
   size_t batch_size{1};
   DataRecord() = default;
@@ -97,6 +98,7 @@ struct DataRecord {
       week_data_all.push_back(std::move(week_data));
       minute_data_all.push_back(std::move(minute_data));
     }
+    num_samples = num_lines;
   }
 };
 void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
@@ -147,89 +149,72 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   }
 }
 
-// Test with a really complicate model.
-void TestRNN1Prediction(bool use_analysis, bool activate_ir, int num_threads) {
-  AnalysisConfig config;
-  config.prog_file = FLAGS_infer_model + "/__model__";
-  config.param_file = FLAGS_infer_model + "/param";
-  config.use_gpu = false;
-  config.device = 0;
-  config.specify_input_name = true;
-  config.enable_ir_optim = activate_ir;
-  PADDLE_ENFORCE(config.ir_mode ==
-                 AnalysisConfig::IrPassMode::kExclude);  // default
-  config.ir_passes.clear();  // Do not exclude any pass.
-
-  int batch_size = FLAGS_batch_size;
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->prog_file = FLAGS_infer_model + "/__model__";
+  cfg->param_file = FLAGS_infer_model + "/param";
+  cfg->use_gpu = false;
+  cfg->device = 0;
+  cfg->specify_input_name = true;
+  cfg->enable_ir_optim = true;
+  cfg->ir_passes.clear();  // Do not exclude any pass.
+}
 
-  auto base_predictor =
-      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
-  auto predictor =
-      CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
-          config);
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
+  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
   std::vector<PaddleTensor> input_slots;
-  DataRecord data(FLAGS_infer_data, batch_size);
-  // Prepare inputs.
-  PrepareInputs(&input_slots, &data, batch_size);
-  std::vector<PaddleTensor> outputs, base_outputs;
+  int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
+  LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
+  for (int bid = 0; bid < epoch; ++bid) {
+    PrepareInputs(&input_slots, &data, FLAGS_batch_size);
+    (*inputs).emplace_back(input_slots);
+  }
+}
 
-  base_predictor->Run(input_slots, &base_outputs);
+// Easy to profile independently.
+TEST(Analyzer_rnn1, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
-  input_slots_all.emplace_back(input_slots);
-  if (num_threads == 1) {
-    TestOneThreadPrediction(config, input_slots_all, &outputs);
-    CompareResult(outputs, base_outputs);
-  } else {
-    // only return the output of first thread
-    TestMultiThreadPrediction(config, input_slots_all, &outputs, num_threads);
-  }
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+}
 
-  if (use_analysis && activate_ir) {
-    AnalysisPredictor *analysis_predictor =
-        dynamic_cast<AnalysisPredictor *>(predictor.get());
-    auto &fuse_statis = analysis_predictor->analysis_argument()
-                            .Get<std::unordered_map<std::string, int>>(
-                                framework::ir::kFuseStatisAttr);
-    for (auto &item : fuse_statis) {
-      LOG(INFO) << "fused " << item.first << " " << item.second;
-    }
+// Check the fuse status
+TEST(Analyzer_rnn1, fuse_statis) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
 
-    int num_ops = 0;
-    for (auto &node :
-         analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
-      if (node->IsFunction()) {
-        ++num_ops;
-      }
-    }
-    LOG(INFO) << "has num ops: " << num_ops;
+  int num_ops;
+  auto fuse_statis = GetFuseStatis(cfg, &num_ops);
+  ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+  EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
+  EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2);  // bi-directional LSTM
+  EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
+  EXPECT_EQ(num_ops,
+            13);  // After graph optimization, only 13 operators exist.
+}
 
-    ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-    EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
-    EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2);  // bi-directional LSTM
-    EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
-    EXPECT_EQ(num_ops,
-              13);  // After graph optimization, only 13 operators exists.
-  }
+// Compare results of NativeConfig and AnalysisConfig
+TEST(Analyzer_rnn1, compare) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
-// Inference with analysis and IR, easy for profiling independently.
-TEST(Analyzer, rnn1) { TestRNN1Prediction(true, true, FLAGS_num_threads); }
+// Test multi-threaded prediction.
+TEST(Analyzer_rnn1, multi_thread) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
 
-// Other unit-tests of RNN1, test different options of use_analysis,
-// activate_ir and multi-threads.
-TEST(Analyzer, RNN_tests) {
-  int num_threads[2] = {1, 4};
-  for (auto i : num_threads) {
-    // Directly infer with the original model.
-    TestRNN1Prediction(false, false, i);
-    // Inference with the original model with the analysis turned on, the
-    // analysis module will transform the program to a data flow graph.
-    TestRNN1Prediction(true, false, i);
-    // Inference with analysis and IR. The IR module will fuse some large
-    // kernels.
-    TestRNN1Prediction(true, true, i);
-  }
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, 4 /* num_threads */);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
index c40ea58eea..ba04d030b9 100644
--- a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
@@ -12,24 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/inference/analysis/analyzer.h"
-
-#include <google/protobuf/text_format.h>
-#include <gtest/gtest.h>
-#include <thread>  // NOLINT
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/pass.h"
-#include "paddle/fluid/inference/analysis/ut_helper.h"
-#include "paddle/fluid/inference/api/analysis_predictor.h"
-#include "paddle/fluid/inference/api/helper.h"
-#include "paddle/fluid/inference/api/paddle_inference_api.h"
-#include "paddle/fluid/inference/api/paddle_inference_pass.h"
-
-DEFINE_string(infer_model, "", "model path");
-DEFINE_string(infer_data, "", "data path");
-DEFINE_int32(batch_size, 1, "batch size.");
-DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
-DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
 
 namespace paddle {
 namespace inference {
@@ -41,6 +24,7 @@ struct DataRecord {
   std::vector<size_t> lod;
   std::vector<std::vector<float>> rnn_link_data;
   std::vector<float> result_data;
+  size_t num_samples;  // total number of samples
   size_t batch_iter{0};
   size_t batch_size{1};
   DataRecord() = default;
@@ -100,6 +84,7 @@ struct DataRecord {
         result_data.insert(result_data.end(), tmp.begin(), tmp.end());
       }
     }
+    num_samples = num_lines / 2;
   }
 };
 void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
@@ -118,64 +103,58 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   input_slots->assign({feed_tensor});
 }
 
-void CompareResult(const std::vector<PaddleTensor> &outputs,
-                   const std::vector<float> &base_result) {
-  PADDLE_ENFORCE_GT(outputs.size(), 0);
-  for (size_t i = 0; i < outputs.size(); i++) {
-    auto &out = outputs[i];
-    size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
-                                  [](int a, int b) { return a * b; });
-    PADDLE_ENFORCE_GT(size, 0);
-    float *data = static_cast<float *>(out.data.data());
-    for (size_t i = 0; i < size; i++) {
-      EXPECT_NEAR(data[i], base_result[i], 1e-3);
-    }
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->prog_file = FLAGS_infer_model + "/__model__";
+  cfg->param_file = FLAGS_infer_model + "/param";
+  cfg->use_gpu = false;
+  cfg->device = 0;
+  cfg->specify_input_name = true;
+  cfg->enable_ir_optim = true;
+}
+
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
+  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+  std::vector<PaddleTensor> input_slots;
+  int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
+  LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
+  for (int bid = 0; bid < epoch; ++bid) {
+    PrepareInputs(&input_slots, &data, FLAGS_batch_size);
+    (*inputs).emplace_back(input_slots);
   }
 }
-// Test with a really complicate model.
-void TestRNN2Prediction() {
-  AnalysisConfig config;
-  config.prog_file = FLAGS_infer_model + "/__model__";
-  config.param_file = FLAGS_infer_model + "/param";
-  config.use_gpu = false;
-  config.device = 0;
-  config.specify_input_name = true;
-  config.enable_ir_optim = true;
-  PADDLE_ENFORCE(config.ir_mode ==
-                 AnalysisConfig::IrPassMode::kExclude);  // default
 
-  int batch_size = FLAGS_batch_size;
-  int num_times = FLAGS_repeat;
+// Easy to profile independently.
+TEST(Analyzer_rnn2, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
 
-  auto base_predictor =
-      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
-  auto predictor =
-      CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
-          config);
-  std::vector<PaddleTensor> input_slots;
-  DataRecord data(FLAGS_infer_data, batch_size);
-  PrepareInputs(&input_slots, &data, batch_size);
-  std::vector<PaddleTensor> outputs, base_outputs;
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
 
-  Timer timer1;
-  timer1.tic();
-  for (int i = 0; i < num_times; i++) {
-    base_predictor->Run(input_slots, &base_outputs);
+  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
+    // Check the first inference result against the reference data.
+    DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(size, 0);
+    float *result = static_cast<float *>(outputs[0].data.data());
+    for (size_t i = 0; i < size; i++) {
+      EXPECT_NEAR(result[i], data.result_data[i], 1e-3);
+    }
   }
-  PrintTime(batch_size, num_times, 1, 0, timer1.toc() / num_times);
+}
 
-  Timer timer2;
-  timer2.tic();
-  for (int i = 0; i < num_times; i++) {
-    predictor->Run(input_slots, &outputs);
-  }
-  PrintTime(batch_size, num_times, 1, 0, timer2.toc() / num_times);
+// Compare results of NativeConfig and AnalysisConfig
+TEST(Analyzer_rnn2, compare) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
 
-  CompareResult(base_outputs, data.result_data);
-  CompareResult(outputs, data.result_data);
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
-TEST(Analyzer, rnn2) { TestRNN2Prediction(); }
-
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
index 1472c475e4..340ef152f0 100644
--- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
@@ -46,54 +46,63 @@ struct DataReader {
   std::unique_ptr<std::ifstream> file;
 };
 
-void Main(int batch_size) {
-  // shape --
-  // Create Predictor --
-  AnalysisConfig config;
-  config.model_dir = FLAGS_infer_model;
-  config.use_gpu = false;
-  config.enable_ir_optim = true;
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->model_dir = FLAGS_infer_model;
+  cfg->use_gpu = false;
+  cfg->device = 0;
+  cfg->specify_input_name = true;
+  cfg->enable_ir_optim = true;
+}
 
-  std::vector<PaddleTensor> input_slots, output_slots;
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
+  std::vector<PaddleTensor> input_slots;
   DataReader reader(FLAGS_infer_data);
-  std::vector<std::vector<PaddleTensor>> input_slots_all;
-
-  if (FLAGS_test_all_data) {
-    LOG(INFO) << "test all data";
-    int num_batches = 0;
-    while (reader.NextBatch(&input_slots, FLAGS_batch_size)) {
-      input_slots_all.emplace_back(input_slots);
-      ++num_batches;
-    }
-    LOG(INFO) << "total number of samples: " << num_batches * FLAGS_batch_size;
-    TestPrediction(config, input_slots_all, &output_slots, FLAGS_num_threads);
-    return;
+  int num_batches = 0;
+  while (reader.NextBatch(&input_slots, FLAGS_batch_size)) {
+    (*inputs).emplace_back(input_slots);
+    ++num_batches;
+    if (!FLAGS_test_all_data) return;
   }
+  LOG(INFO) << "total number of samples: " << num_batches * FLAGS_batch_size;
+}
 
-  // one batch starts
-  // data --
-  reader.NextBatch(&input_slots, FLAGS_batch_size);
-  input_slots_all.emplace_back(input_slots);
-  TestPrediction(config, input_slots_all, &output_slots, FLAGS_num_threads);
+// Easy to profile independently.
+TEST(Analyzer_Text_Classification, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
 
-  // Get output
-  LOG(INFO) << "get outputs " << output_slots.size();
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
 
-  for (auto &output : output_slots) {
-    LOG(INFO) << "output.shape: " << to_string(output.shape);
-    // no lod ?
-    CHECK_EQ(output.lod.size(), 0UL);
-    LOG(INFO) << "output.dtype: " << output.dtype;
-    std::stringstream ss;
-    for (int i = 0; i < 5; i++) {
-      ss << static_cast<float *>(output.data.data())[i] << " ";
+  if (FLAGS_num_threads == 1) {
+    // Get output
+    LOG(INFO) << "get outputs " << outputs.size();
+    for (auto &output : outputs) {
+      LOG(INFO) << "output.shape: " << to_string(output.shape);
+      // The output is expected to have no LoD.
+      CHECK_EQ(output.lod.size(), 0UL);
+      LOG(INFO) << "output.dtype: " << output.dtype;
+      std::stringstream ss;
+      for (int i = 0; i < 5; i++) {
+        ss << static_cast<float *>(output.data.data())[i] << " ";
+      }
+      LOG(INFO) << "output.data summary: " << ss.str();
+      // one batch ends
     }
-    LOG(INFO) << "output.data summary: " << ss.str();
-    // one batch ends
   }
 }
 
-TEST(text_classification, basic) { Main(FLAGS_batch_size); }
+// Compare results of NativeConfig and AnalysisConfig
+TEST(Analyzer_Text_Classification, compare) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
+}
 
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index a207c41b71..483ae66c5b 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -49,84 +49,83 @@ Record ProcessALine(const std::string &line) {
   return record;
 }
 
-/*
- * Use the native and analysis fluid engine to inference the demo.
- * ocr, mobilenet and se_resnext50
- */
-void TestVisualPrediction(bool use_mkldnn) {
-  std::unique_ptr<PaddlePredictor> predictor;
-  AnalysisConfig cfg;
-  cfg.param_file = FLAGS_infer_model + "/__params__";
-  cfg.prog_file = FLAGS_infer_model + "/__model__";
-  cfg.use_gpu = false;
-  cfg._use_mkldnn = use_mkldnn;
-  cfg.device = 0;
-  cfg.enable_ir_optim = true;
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->param_file = FLAGS_infer_model + "/__params__";
+  cfg->prog_file = FLAGS_infer_model + "/__model__";
+  cfg->use_gpu = false;
+  cfg->device = 0;
+  cfg->enable_ir_optim = true;
+  cfg->specify_input_name = true;
   // TODO(TJ): fix fusion gru
-  cfg.ir_passes.push_back("fc_gru_fuse_pass");
+  cfg->ir_passes.push_back("fc_gru_fuse_pass");
 #ifdef PADDLE_WITH_MKLDNN
+  cfg->_use_mkldnn = true;
+  // disable the mkldnn fuse pass since it seems to have bugs
-  cfg.ir_passes.push_back("conv_relu_mkldnn_fuse_pass");
+  cfg->ir_passes.push_back("conv_relu_mkldnn_fuse_pass");
 #endif
-  predictor =
-      CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(cfg);
+}
 
-  // Only have single batch of data.
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
+  PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0,
+                    "Only a single batch of data is supported.");
   std::string line;
   std::ifstream file(FLAGS_infer_data);
   std::getline(file, line);
   auto record = ProcessALine(line);
-  file.close();
 
-  // Inference.
   PaddleTensor input;
   input.shape = record.shape;
-  input.data =
-      PaddleBuf(record.data.data(), record.data.size() * sizeof(float));
   input.dtype = PaddleDType::FLOAT32;
+  size_t input_size = record.data.size() * sizeof(float);
+  input.data.Resize(input_size);
+  memcpy(input.data.data(), record.data.data(), input_size);
+  std::vector<PaddleTensor> input_slots;
+  input_slots.assign({input});
+  (*inputs).emplace_back(input_slots);
+}
 
-  std::vector<PaddleTensor> outputs_slots;
-  Timer timer;
-  timer.tic();
-  for (int i = 0; i < FLAGS_repeat; i++) {
-    predictor->Run({input}, &outputs_slots);
-  }
-  PrintTime(/*batch size*/ 1, FLAGS_repeat, /*num threads*/ 1, /*thread id*/ 0,
-            timer.toc() / FLAGS_repeat);
-
-  VLOG(3) << "output.size " << outputs_slots.size();
-
-  // run native as reference
-  auto ref_predictor =
-      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(cfg);
-  std::vector<PaddleTensor> ref_outputs_slots;
-  ref_predictor->Run({input}, &ref_outputs_slots);
-  CompareResult(outputs_slots, ref_outputs_slots);
-  // print what are fused
-  AnalysisPredictor *analysis_predictor =
-      dynamic_cast<AnalysisPredictor *>(predictor.get());
-  auto &fuse_statis = analysis_predictor->analysis_argument()
-                          .Get<std::unordered_map<std::string, int>>(
-                              framework::ir::kFuseStatisAttr);
-  for (auto &item : fuse_statis) {
-    LOG(INFO) << "fused " << item.first << " " << item.second;
-  }
-  int num_ops = 0;
-  for (auto &node :
-       analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
-    if (node->IsFunction()) {
-      ++num_ops;
+// Easy to profile independently.
+// Models: ocr, mobilenet and se_resnext50.
+TEST(Analyzer_vis, profile) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+
+  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
+    const float ocr_result_data[] = {
+        5.273636460856323538e-08, 3.296741795111302054e-07,
+        1.873261190610264748e-08, 3.403730275408634043e-08,
+        3.383312474625199684e-08};
+    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
+    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(size, 0);
+    float *result = static_cast<float *>(outputs[0].data.data());
+    for (size_t i = 0; i < std::min(5UL, size); i++) {
+      EXPECT_NEAR(result[i], ocr_result_data[i], 1e-3);
     }
   }
-  LOG(INFO) << "has num ops: " << num_ops;
 }
 
-TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
-#ifdef PADDLE_WITH_MKLDNN
-TEST(Analyzer_vis, analysis_mkldnn) {
-  TestVisualPrediction(/*use_mkldnn*/ true);
+// Check the fuse status
+TEST(Analyzer_vis, fuse_statis) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  int num_ops;
+  GetFuseStatis(cfg, &num_ops);
+}
+
+// Compare results of NativeConfig and AnalysisConfig
+TEST(Analyzer_vis, compare) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
 }
-#endif
 
 }  // namespace analysis
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 43e97614e3..384a40a3f9 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <gtest/gtest.h>
+#include <string>
 #include <thread>  // NOLINT
 #include <vector>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
@@ -28,17 +29,18 @@
 DEFINE_string(infer_model, "", "model path");
 DEFINE_string(infer_data, "", "data file");
 DEFINE_int32(batch_size, 1, "batch size.");
-DEFINE_int32(burning, 0, "Burning before repeat.");
 DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
 DEFINE_bool(test_all_data, false, "Test the all dataset in data file.");
 DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
+DEFINE_bool(use_analysis, true,
+            "Running the inference program in analysis mode.");
 
 namespace paddle {
 namespace inference {
 
 void CompareResult(const std::vector<PaddleTensor> &outputs,
                    const std::vector<PaddleTensor> &ref_outputs) {
-  EXPECT_GT(outputs.size(), 0);
+  EXPECT_GT(outputs.size(), 0UL);
   EXPECT_EQ(outputs.size(), ref_outputs.size());
   for (size_t i = 0; i < outputs.size(); i++) {
     auto &out = outputs[i];
@@ -72,14 +74,50 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   }
 }
 
+std::unique_ptr<PaddlePredictor> GetPrediction(AnalysisConfig config,
+                                               bool use_analysis = true) {
+  if (use_analysis) {
+    return CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
+        config);
+  } else {
+    return CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
+        config);
+  }
+}
+
+size_t GetSize(const PaddleTensor &out) {
+  return std::accumulate(out.shape.begin(), out.shape.end(), 1,
+                         [](int a, int b) { return a * b; });
+}
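+
+// For illustration: a PaddleTensor with shape {2, 3, 4} gives
+// GetSize(out) == 24.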
+
+std::unordered_map<std::string, int> GetFuseStatis(AnalysisConfig config,
+                                                   int *num_ops) {
+  auto predictor = GetPrediction(config);
+  AnalysisPredictor *analysis_predictor =
+      dynamic_cast<AnalysisPredictor *>(predictor.get());
+  auto &fuse_statis = analysis_predictor->analysis_argument()
+                          .Get<std::unordered_map<std::string, int>>(
+                              framework::ir::kFuseStatisAttr);
+  for (auto &item : fuse_statis) {
+    LOG(INFO) << "fused " << item.first << " " << item.second;
+  }
+  int num = 0;
+  for (auto &node :
+       analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
+    if (node->IsFunction()) {
+      ++num;
+    }
+  }
+  *num_ops = num;
+  return fuse_statis;
+}
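+
+// For illustration: for the RNN1 model exercised in
+// analyzer_rnn1_tester.cc, fuse_statis contains {"fc_fuse": 1,
+// "fc_nobias_lstm_fuse": 2, "seq_concat_fc_fuse": 1} and *num_ops is
+// set to 13 after graph optimization.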
+
 void TestOneThreadPrediction(
     AnalysisConfig config, const std::vector<std::vector<PaddleTensor>> inputs,
-    std::vector<PaddleTensor> *outputs) {
+    std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
   int batch_size = FLAGS_batch_size;
   int num_times = FLAGS_repeat;
-  auto predictor =
-      CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
-          config);
+  auto predictor = GetPrediction(config, use_analysis);
   Timer timer;
   timer.tic();
   for (int i = 0; i < num_times; i++) {
@@ -93,7 +131,8 @@ void TestOneThreadPrediction(
 
 void TestMultiThreadPrediction(
     AnalysisConfig config, const std::vector<std::vector<PaddleTensor>> inputs,
-    std::vector<PaddleTensor> *outputs, int num_threads) {
+    std::vector<PaddleTensor> *outputs, int num_threads,
+    bool use_analysis = true) {
   int batch_size = FLAGS_batch_size;
   int num_times = FLAGS_repeat;
   std::vector<std::thread> threads;
@@ -101,9 +140,7 @@ void TestMultiThreadPrediction(
   // TODO(yanchunwei): Bug here, the analyzer phase can't be parallelized
   // because AttentionLSTM's hard-coded node id will be damaged.
   for (int tid = 0; tid < num_threads; ++tid) {
-    predictors.emplace_back(
-        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
-            config));
+    predictors.emplace_back(GetPrediction(config, use_analysis));
   }
   for (int tid = 0; tid < num_threads; ++tid) {
     threads.emplace_back([&, tid]() {
@@ -129,13 +166,25 @@ void TestMultiThreadPrediction(
 
 void TestPrediction(AnalysisConfig config,
                     const std::vector<std::vector<PaddleTensor>> inputs,
-                    std::vector<PaddleTensor> *outputs, int num_threads) {
+                    std::vector<PaddleTensor> *outputs, int num_threads,
+                    bool use_analysis = FLAGS_use_analysis) {
+  LOG(INFO) << "use_analysis: " << use_analysis;
   if (num_threads == 1) {
-    TestOneThreadPrediction(config, inputs, outputs);
+    TestOneThreadPrediction(config, inputs, outputs, use_analysis);
   } else {
-    TestMultiThreadPrediction(config, inputs, outputs, num_threads);
+    TestMultiThreadPrediction(config, inputs, outputs, num_threads,
+                              use_analysis);
   }
 }
 
+void CompareNativeAndAnalysis(
+    AnalysisConfig config,
+    const std::vector<std::vector<PaddleTensor>> inputs) {
+  std::vector<PaddleTensor> native_outputs, analysis_outputs;
+  TestOneThreadPrediction(config, inputs, &native_outputs, false);
+  TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
+  CompareResult(analysis_outputs, native_outputs);
+}
+
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h
index 84a584f424..5b27068c9e 100644
--- a/paddle/fluid/operators/adam_op.h
+++ b/paddle/fluid/operators/adam_op.h
@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
 
   const int64_t* rows_;
   int64_t row_numel_;
+  int64_t row_count_;
 
   SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                     const T* beta2_pow, const T* mom1, T* mom1_out,
                     const T* mom2, T* mom2_out, const T* lr, const T* grad,
                     const T* param, T* param_out, const int64_t* rows,
-                    int64_t row_numel)
+                    int64_t row_numel, int64_t row_count)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
         param_(param),
         param_out_(param_out),
         rows_(rows),
-        row_numel_(row_numel) {}
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
+    int64_t beg = 0, end = row_count_ - 1;
+    while (beg <= end) {
+      auto mid = ((beg + end) >> 1);
+      if (rows_[mid] == row)
+        return mid;
+      else if (rows_[mid] < row)
+        beg = mid + 1;
+      else
+        end = mid - 1;
+    }
+    return -1;
+  }
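+
+  // For illustration (assumed values): with rows_ = {0, 3, 7} and
+  // row_numel_ = 2, element i = 6 falls in dense row 6 / 2 = 3;
+  // BinarySearchInRows(3) returns 1, so its gradient is read from
+  // grad_[1 * 2 + 6 % 2] = grad_[2]. A dense row absent from rows_
+  // (e.g. row 1) yields -1 and is treated as having zero gradient.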
 
   inline HOSTDEVICE void operator()(size_t i) const {
+    int64_t row = i / row_numel_;
+    auto row_idx = BinarySearchInRows(row);
+    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
+
+    // The following code is the same as the dense update
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
-    for (int64_t j = 0; j < row_numel_; ++j) {
-      T g = grad_[i * row_numel_ + j];
-      T mom1 = moment1_[rows_[i] * row_numel_ + j];
-      T mom2 = moment2_[rows_[i] * row_numel_ + j];
-      T lr = *lr_;
-      T p = param_[rows_[i] * row_numel_ + j];
-
-      lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-
-      mom1 = beta1_ * mom1 + (1 - beta1_) * g;
-      mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-
-      moment1_out_[rows_[i] * row_numel_ + j] = mom1;
-      moment2_out_[rows_[i] * row_numel_ + j] = mom2;
-      param_out_[rows_[i] * row_numel_ + j] = p;
-    }  // for col id
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
   }
 };
 
@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
         return;
       }
       // merge duplicated rows if any.
+      // The rows of grad_merge have been sorted inside the MergeAdd functor
       scatter::MergeAdd<DeviceContext, T> merge_func;
-      auto grad_merge =
-          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_merge = *(ctx.scope()
+                               .NewScope()
+                               .Var("sparse_adam_grad_merge")
+                               ->GetMutable<framework::SelectedRows>());
+      merge_func(ctx.template device_context<DeviceContext>(), grad,
+                 &grad_merge);
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
       int64_t* rows = nullptr;
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           mom2.template data<T>(),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
+          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size());
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
-          grad_merge.rows().size());
+          param.numel());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index 85607a6b0e..daf06f370f 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"
 
 namespace paddle {
@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* x = context.Input<Tensor>("X");
-    auto* out = context.Output<Tensor>("Out");
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    const T* x_data = x->data<T>();
-    int64_t numel = x->numel();
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), x_data,
-          x_data + numel, out_data, ClipFunctor<T>(min, max));
+    auto* x_var = context.InputVar("X");
+    if (x_var->IsType<framework::LoDTensor>()) {
+      auto* x = context.Input<framework::LoDTensor>("X");
+      auto* out = context.Output<framework::LoDTensor>("Out");
+      T* out_data = out->mutable_data<T>(context.GetPlace());
+      const T* x_data = x->data<T>();
+      int64_t numel = x->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), x_data,
+            x_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else if (x_var->IsType<framework::SelectedRows>()) {
+      auto* x = context.Input<framework::SelectedRows>("X");
+      auto* out = context.Output<framework::SelectedRows>("Out");
+      PADDLE_ENFORCE_NE(x, out,
+                        "Inplace clip is not allowed when x is SelectedRows");
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(context.template device_context<DeviceContext>(), *x, out);
+      auto* out_tensor = out->mutable_value();
+      auto* out_data = out_tensor->data<T>();
+      int64_t numel = out_tensor->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), out_data,
+            out_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else {
+      PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
+    }
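+    // For illustration (assumed values): a SelectedRows input with rows
+    // {0, 4, 0} is first merged by MergeAdd into rows {0, 4} with the
+    // duplicated rows summed; the clip is then applied elementwise to
+    // the merged value tensor.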
   }
 };
 
@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* d_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
     if (d_x != nullptr) {
-      auto* x = context.Input<Tensor>("X");
+      auto* x = context.Input<framework::LoDTensor>("X");
       int64_t numel = d_out->numel();
       auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
       const T* d_out_data = d_out->data<T>();
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index f4983c6543..5a058ddbc5 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -31,5 +31,6 @@ polygon_box_transform_op.cu)
 detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
 detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
 detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
+detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
 #Export local libraries to parent
 set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc
new file mode 100644
index 0000000000..b98190d40a
--- /dev/null
+++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc
@@ -0,0 +1,587 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+static constexpr int kROISize = 4;
+
+template <typename T>
+bool GT_E(T a, T b) {
+  return (a > b) || fabs(a - b) < 1e-4;
+}
+
+template <typename T>
+bool LT_E(T a, T b) {
+  return (a < b) || fabs(a - b) < 1e-4;
+}
+
+template <typename T>
+bool GT(T a, T b) {
+  return (a - b) > 1e-4;
+}
+
+/*
+ * Check whether (x, y) is inside the boundary of the RoI.
+ */
+template <typename T>
+bool in_quad(T x, T y, T roi_x[], T roi_y[]) {
+  for (int i = 0; i < 4; i++) {
+    T xs = roi_x[i];
+    T ys = roi_y[i];
+    T xe = roi_x[(i + 1) % 4];
+    T ye = roi_y[(i + 1) % 4];
+    if (fabs(ys - ye) < 1e-4) {
+      if (fabs(y - ys) < 1e-4 && fabs(y - ye) < 1e-4 &&
+          GT_E<T>(x, std::min(xs, xe)) && LT_E<T>(x, std::max(xs, xe))) {
+        return true;
+      }
+    } else {
+      T intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs;
+      if (fabs(intersec_x - x) < 1e-4 && GT_E<T>(y, std::min(ys, ye)) &&
+          LT_E<T>(y, std::max(ys, ye))) {
+        return true;
+      }
+    }
+  }
+
+  int n_cross = 0;
+  for (int i = 0; i < 4; i++) {
+    T xs = roi_x[i];
+    T ys = roi_y[i];
+    T xe = roi_x[(i + 1) % 4];
+    T ye = roi_y[(i + 1) % 4];
+    if (fabs(ys - ye) < 1e-4) {
+      continue;
+    }
+    if (LT_E<T>(y, std::min(ys, ye)) || GT<T>(y, std::max(ys, ye))) {
+      continue;
+    }
+    T intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs;
+    if (fabs(intersec_x - x) < 1e-4) {
+      return true;
+    }
+    if (GT<T>(intersec_x, x)) {
+      n_cross++;
+    }
+  }
+  return (n_cross % 2 == 1);
+}
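+
+// For illustration: a point strictly inside a convex quad has exactly
+// one edge crossing on the horizontal ray to its right, so n_cross == 1
+// and in_quad returns true; a point outside has an even crossing count
+// (0 or 2) and yields false.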
+
+/**
+ * Get the matrix of perspective transform.
+ *
+ * dx1 = x1 - x2
+ * dx2 = x3 - x2
+ * dx3 = x0 - x1 + x2 - x3
+ * dy1 = y1 - y2
+ * dy2 = y3 - y2
+ * dy3 = y0 - y1 + y2 - y3
+ *
+ * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1)
+ * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1)
+ * a13 = x0
+ * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1)
+ * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1)
+ * a23 = y0
+ * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1)
+ * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1)
+ * a33 = 1
+ *
+ */
+template <typename T>
+void get_transform_matrix(const int transformed_width,
+                          const int transformed_height, T roi_x[], T roi_y[],
+                          T matrix[]) {
+  T x0 = roi_x[0];
+  T x1 = roi_x[1];
+  T x2 = roi_x[2];
+  T x3 = roi_x[3];
+  T y0 = roi_y[0];
+  T y1 = roi_y[1];
+  T y2 = roi_y[2];
+  T y3 = roi_y[3];
+
+  // Estimate the height and width of RoI
+  T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1));
+  T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2));
+  T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3));
+  T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0));
+  T estimated_height = (len2 + len4) / 2.0;
+  T estimated_width = (len1 + len3) / 2.0;
+
+  // Get the normalized height and normalized width
+  int normalized_height = transformed_height;
+  int normalized_width =
+      std::round(estimated_width * (normalized_height - 1) / estimated_height) +
+      1;
+  normalized_width = std::min(normalized_width, transformed_width);
+
+  T dx1 = x1 - x2;
+  T dx2 = x3 - x2;
+  T dx3 = x0 - x1 + x2 - x3;
+  T dy1 = y1 - y2;
+  T dy2 = y3 - y2;
+  T dy3 = y0 - y1 + y2 - y3;
+
+  matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) /
+              (normalized_width - 1);
+  matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) /
+              (normalized_height - 1);
+  matrix[8] = 1;
+
+  matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) /
+              (normalized_width - 1);
+  matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) /
+              (normalized_height - 1);
+  matrix[5] = y0;
+
+  matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) /
+              (normalized_width - 1);
+  matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) /
+              (normalized_height - 1);
+  matrix[2] = x0;
+}
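+
+// Worked example (assumed values): for an axis-aligned square RoI with
+// corners (0, 0), (a, 0), (a, a), (0, a) and a square output
+// (transformed_width == transformed_height == n), dx3 = dy3 = 0, so
+// matrix[6] = matrix[7] = 0 and the transform reduces to a pure
+// scaling: in_w = a / (n - 1) * out_w and in_h = a / (n - 1) * out_h.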
+
+/**
+ * Get the source coordinates in the input feature map.
+ *
+ * (u, v, w)^T = matrix * (out_w, out_h, 1)^T
+ *
+ * in_w = u / w
+ * in_h = v / w
+ *
+ */
+template <typename T>
+void get_source_coords(T matrix[], int out_w, int out_h, T* in_w, T* in_h) {
+  T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2];
+  T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5];
+  T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8];
+
+  in_w[0] = u / w;
+  in_h[0] = v / w;
+}
+
+/**
+ * Perform bilinear interpolation in the input feature map.
+ */
+template <typename T>
+void bilinear_interpolate(const T* in_data, const int channels, const int width,
+                          const int height, int in_n, int in_c, T in_w, T in_h,
+                          T* val) {
+  // Deal with cases where the source coords fall outside the feature
+  // map boundary.
+  if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) ||
+      GT<T>(in_h, height - 0.5)) {
+    // empty
+    val[0] = 0.0;
+    return;
+  }
+
+  if (GT<T>(0, in_w)) {
+    in_w = 0;
+  }
+  if (GT<T>(0, in_h)) {
+    in_h = 0;
+  }
+
+  int in_w_floor = floor(in_w);
+  int in_h_floor = floor(in_h);
+  int in_w_ceil;
+  int in_h_ceil;
+
+  if (GT_E<T>(in_w_floor, width - 1)) {
+    in_w_ceil = in_w_floor = width - 1;
+    in_w = static_cast<T>(in_w_floor);
+  } else {
+    in_w_ceil = in_w_floor + 1;
+  }
+
+  if (GT_E<T>(in_h_floor, height - 1)) {
+    in_h_ceil = in_h_floor = height - 1;
+    in_h = static_cast<T>(in_h_floor);
+  } else {
+    in_h_ceil = in_h_floor + 1;
+  }
+  T w_floor = in_w - in_w_floor;
+  T h_floor = in_h - in_h_floor;
+  T w_ceil = 1 - w_floor;
+  T h_ceil = 1 - h_floor;
+  const T* data = in_data + (in_n * channels + in_c) * height * width;
+  // Do bilinear interpolation
+  T v1 = data[in_h_floor * width + in_w_floor];
+  T v2 = data[in_h_ceil * width + in_w_floor];
+  T v3 = data[in_h_ceil * width + in_w_ceil];
+  T v4 = data[in_h_floor * width + in_w_ceil];
+  T w1 = w_ceil * h_ceil;
+  T w2 = w_ceil * h_floor;
+  T w3 = w_floor * h_floor;
+  T w4 = w_floor * h_ceil;
+  val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+}
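+
+// Worked example (assumed values): for in_w = 1.25 and in_h = 2.5, the
+// four neighbours are (w, h) = (1, 2), (1, 3), (2, 3) and (2, 2), with
+// w_floor = 0.25 and h_floor = 0.5, giving weights w1 = 0.375,
+// w2 = 0.375, w3 = 0.125 and w4 = 0.125, which sum to 1.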
+
+template <typename T>
+class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+
+    auto transformed_height = ctx.Attr<int>("transformed_height");
+    auto transformed_width = ctx.Attr<int>("transformed_width");
+    auto spatial_scale = ctx.Attr<float>("spatial_scale");
+
+    auto in_dims = in->dims();
+    int channels = in_dims[1];
+    int in_height = in_dims[2];
+    int in_width = in_dims[3];
+    int rois_num = rois->dims()[0];
+
+    const T* input_data = in->data<T>();
+
+    framework::Tensor roi2image;
+    roi2image.Resize({rois_num});
+    int* roi2image_data = roi2image.mutable_data<int>(ctx.GetPlace());
+    auto lod = rois->lod().back();
+    for (int i = 0; i < lod.size() - 1; ++i) {
+      for (int j = lod[i]; j < lod[i + 1]; ++j) {
+        roi2image_data[j] = i;
+      }
+    }
+
+    T* output_data = out->mutable_data<T>(ctx.GetPlace());
+    const T* rois_data = rois->data<T>();
+
+    for (int n = 0; n < rois_num; ++n) {
+      const T* n_rois = rois_data + n * 8;
+      T roi_x[4];
+      T roi_y[4];
+      for (int k = 0; k < 4; ++k) {
+        roi_x[k] = n_rois[2 * k] * spatial_scale;
+        roi_y[k] = n_rois[2 * k + 1] * spatial_scale;
+      }
+      int image_id = roi2image_data[n];
+      // Get transform matrix
+      T transform_matrix[9];
+      get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
+                              roi_y, transform_matrix);
+
+      for (int c = 0; c < channels; ++c) {
+        for (int out_h = 0; out_h < transformed_height; ++out_h) {
+          for (int out_w = 0; out_w < transformed_width; ++out_w) {
+            int out_index =
+                n * channels * transformed_height * transformed_width +
+                c * transformed_height * transformed_width +
+                out_h * transformed_width + out_w;
+            T in_w, in_h;
+            get_source_coords<T>(transform_matrix, out_w, out_h, &in_w, &in_h);
+            if (in_quad<T>(in_w, in_h, roi_x, roi_y)) {
+              if (GT<T>(-0.5, in_w) ||
+                  GT<T>(in_w, static_cast<T>(in_width - 0.5)) ||
+                  GT<T>(-0.5, in_h) ||
+                  GT<T>(in_h, static_cast<T>(in_height - 0.5))) {
+                output_data[out_index] = 0.0;
+              } else {
+                bilinear_interpolate(input_data, channels, in_width, in_height,
+                                     image_id, c, in_w, in_h,
+                                     output_data + out_index);
+              }
+            } else {
+              output_data[out_index] = 0.0;
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+T get_feature_gradient(T xs, T ys, int w, int h, const int width,
+                       const int height) {
+  if (GT<T>(-0.5, xs) || GT<T>(xs, width - 0.5) || GT<T>(-0.5, ys) ||
+      GT<T>(ys, height - 0.5)) {
+    return 0;
+  }
+
+  if (GT<T>(0, xs)) {
+    xs = 0;
+  }
+  if (GT<T>(0, ys)) {
+    ys = 0;
+  }
+
+  int xs_floor = floor(xs);
+  int ys_floor = floor(ys);
+  int xs_ceil;
+  int ys_ceil;
+
+  if (GT_E(xs_floor, width - 1)) {
+    xs_ceil = xs_floor = width - 1;
+    xs = static_cast<T>(xs_floor);
+  } else {
+    xs_ceil = xs_floor + 1;
+  }
+
+  if (GT_E(ys_floor, height - 1)) {
+    ys_ceil = ys_floor = height - 1;
+    ys = static_cast<T>(ys_floor);
+  } else {
+    ys_ceil = ys_floor + 1;
+  }
+
+  T weight = 0;
+  if (w == xs_floor) {
+    if (h == ys_floor) {
+      weight = (w + 1 - xs) * (h + 1 - ys);
+    } else if (h == ys_ceil) {
+      weight = (w + 1 - xs) * (ys + 1 - h);
+    }
+  } else if (w == xs_ceil) {
+    if (h == ys_floor) {
+      weight = (xs + 1 - w) * (h + 1 - ys);
+    } else if (h == ys_ceil) {
+      weight = (xs + 1 - w) * (ys + 1 - h);
+    }
+  }
+  return weight;
+}
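+
+// For illustration: get_feature_gradient returns the bilinear weight
+// that input pixel (w, h) received when the point (xs, ys) was sampled
+// in the forward pass. E.g. for xs = 1.25, ys = 2.5 and (w, h) = (1, 2),
+// the weight is (1 + 1 - 1.25) * (2 + 1 - 2.5) = 0.375, matching w1 in
+// the bilinear_interpolate example above.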
+
+template <typename T>
+class CPUROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
+    auto* out_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto transformed_height = ctx.Attr<int>("transformed_height");
+    auto transformed_width = ctx.Attr<int>("transformed_width");
+    auto spatial_scale = ctx.Attr<float>("spatial_scale");
+
+    auto in_dims = in->dims();
+    int batch_size = in_dims[0];
+    int channels = in_dims[1];
+    int in_height = in_dims[2];
+    int in_width = in_dims[3];
+    int rois_num = rois->dims()[0];
+
+    T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
+    const T* out_grad_data = out_grad->data<T>();
+    const T* rois_data = rois->data<T>();
+
+    framework::Tensor roi2image;
+    roi2image.Resize({rois_num});
+    int* roi2image_data = roi2image.mutable_data<int>(ctx.GetPlace());
+    auto lod = rois->lod().back();
+    for (int i = 0; i < lod.size() - 1; ++i) {
+      for (int j = lod[i]; j < lod[i + 1]; ++j) {
+        roi2image_data[j] = i;
+      }
+    }
+
+    for (int n = 0; n < batch_size; ++n) {
+      for (int c = 0; c < channels; ++c) {
+        for (int in_h = 0; in_h < in_height; ++in_h) {
+          for (int in_w = 0; in_w < in_width; ++in_w) {
+            T gradient = 0.0;
+            for (int roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) {
+              const T* rois = rois_data + roi_idx * 8;
+              T roi_x[4];
+              T roi_y[4];
+              for (int k = 0; k < 4; ++k) {
+                roi_x[k] = rois[2 * k] * spatial_scale;
+                roi_y[k] = rois[2 * k + 1] * spatial_scale;
+              }
+
+              // Get transform matrix
+              T matrix[9];
+              get_transform_matrix<T>(transformed_width, transformed_height,
+                                      roi_x, roi_y, matrix);
+              const T* out_grad_ptr = out_grad_data +
+                                      (roi_idx * channels + c) *
+                                          transformed_height *
+                                          transformed_width;
+              for (int out_h = 0; out_h < transformed_height; ++out_h) {
+                for (int out_w = 0; out_w < transformed_width; ++out_w) {
+                  T src_w;
+                  T src_h;
+                  get_source_coords<T>(matrix, out_w, out_h, &src_w, &src_h);
+                  if (in_quad<T>(src_w, src_h, roi_x, roi_y)) {
+                    if (GT<T>(-0.5, src_w) ||
+                        GT<T>(src_w, static_cast<T>(in_width - 0.5)) ||
+                        GT<T>(-0.5, src_h) ||
+                        GT<T>(src_h, static_cast<T>(in_height - 0.5))) {
+                      continue;
+                    }
+                    T weight = get_feature_gradient<T>(src_w, src_h, in_w, in_h,
+                                                       in_width, in_height);
+                    gradient +=
+                        out_grad_ptr[out_h * transformed_width + out_w] *
+                        weight;
+                  }
+                }
+              }
+            }
+            int out_idx = (n * channels + c) * in_height * in_width +
+                          in_h * in_width + in_w;
+            in_grad_data[out_idx] = gradient;
+          }
+        }
+      }
+    }
+  }
+};
+
+class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ROIPerspectiveTransformOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("ROIs"),
+        "Input(ROIs) of ROIPerspectiveTransformOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of ROIPerspectiveTransformOp should not be null.");
+    auto input_dims = ctx->GetInputDim("X");
+    auto rois_dims = ctx->GetInputDim("ROIs");
+
+    PADDLE_ENFORCE(input_dims.size() == 4,
+                   "The format of input tensor is NCHW.");
+    PADDLE_ENFORCE(rois_dims.size() == 2,
+                   "ROIs should be a 2-D LoDTensor of shape (num_rois, 8) "
+                   "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...].");
+    PADDLE_ENFORCE(rois_dims[1] == 8,
+                   "ROIs should be a 2-D LoDTensor of shape (num_rois, 8) "
+                   "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...].");
+
+    int transformed_height = ctx->Attrs().Get<int>("transformed_height");
+    int transformed_width = ctx->Attrs().Get<int>("transformed_width");
+    float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
+
+    PADDLE_ENFORCE_GT(transformed_height, 0,
+                      "The transformed output height must be greater than 0.");
+    PADDLE_ENFORCE_GT(transformed_width, 0,
+                      "The transformed output width must be greater than 0.");
+    PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
+                      "The spatial scale must be greater than 0.");
+    std::vector<int64_t> out_dims_v({rois_dims[0],   // num_rois
+                                     input_dims[1],  // channels
+                                     static_cast<int64_t>(transformed_height),
+                                     static_cast<int64_t>(transformed_width)});
+    auto out_dims = framework::make_ddim(out_dims_v);
+
+    ctx->SetOutputDim("Out", out_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.device_context());
+  }
+};
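+
+// Shape illustration (assumed values): with Input(X) of shape
+// {2, 256, 28, 28}, Input(ROIs) of shape {10, 8},
+// transformed_height = 7 and transformed_width = 7, Output(Out) has
+// shape {10, 256, 7, 7}.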
+
+class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "The gradient of Out should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
+                   "The gradient of X should not be null.");
+    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
+class ROIPerspectiveTransformOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor), "
+             "the input of ROIPerspectiveTransformOp. "
+             "The format of input tensor is NCHW. Where N is batch size, "
+             "C is the number of input channels, "
+             "H is the height of the feature, and "
+             "W is the width of the feature.");
+    AddInput("ROIs",
+             "(LoDTensor), "
+             "ROIs (Regions of Interest) to be transformed. "
+             "should be a 2-D LoDTensor of shape (num_rois, 8)"
+             "given as [[x1, y1, x2, y2, x3, y3, x4, y4], ...]."
+             "(x1, y1) is the top left coordinates, and "
+             "(x2, y2) is the top right coordinates, and"
+             "(x3, y3) is the bottom right coordinates, and"
+             "(x4, y4) is the bottom left coordinates.");
+    AddOutput(
+        "Out",
+        "(Tensor), "
+        "The output of ROIPerspectiveTransformOp is a 4-D tensor with shape "
+        "(num_rois, channels, transformed_h, transformed_w).");
+    AddAttr<float>("spatial_scale",
+                   "(float, default 1.0), "
+                   "Spatial scale factor to scale ROI coords.")
+        .SetDefault(1.0);
+    AddAttr<int>("transformed_height",
+                 "(int, default 1), "
+                 "The height of transformed output.")
+        .SetDefault(1);
+    AddAttr<int>("transformed_width",
+                 "(int, default 1), "
+                 "The width of transformed output.")
+        .SetDefault(1);
+    AddComment(R"DOC(
+**ROIPerspectiveTransform Operator**
+
+    )DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(roi_perspective_transform, ops::ROIPerspectiveTransformOp,
+                  ops::ROIPerspectiveTransformOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(roi_perspective_transform_grad,
+                  ops::ROIPerspectiveTransformGradOp);
+REGISTER_OP_CPU_KERNEL(roi_perspective_transform,
+                       ops::CPUROIPerspectiveTransformOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(roi_perspective_transform_grad,
+                       ops::CPUROIPerspectiveTransformGradOpKernel<float>);
diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu
new file mode 100644
index 0000000000..b683b7573d
--- /dev/null
+++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu
@@ -0,0 +1,523 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+// CUDA: index helpers
+#define idx4_4(index, d1, d2, d3, d4) (index % d4)
+#define idx4_3(index, d1, d2, d3, d4) ((index / d4) % d3)
+#define idx4_2(index, d1, d2, d3, d4) ((index / d4 / d3) % d2)
+#define idx4_1(index, d1, d2, d3, d4) ((index / d4 / d3 / d2) % d1)
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                              \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
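+
+// For illustration: this is a grid-stride loop. With
+// stride = blockDim.x * gridDim.x, the thread starting at index t
+// processes elements t, t + stride, t + 2 * stride, ..., so all n
+// elements are covered regardless of the launch configuration.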
+
+template <typename T>
+__device__ bool GT_E(T a, T b) {
+  return (a > b) || fabs(a - b) < 1e-4;
+}
+
+template <typename T>
+__device__ bool LT_E(T a, T b) {
+  return (a < b) || fabs(a - b) < 1e-4;
+}
+
+template <typename T>
+__device__ bool GT(T a, T b) {
+  return (a - b) > 1e-4;
+}
+
+template <typename T>
+__device__ T max(T a, T b) {
+  return a > b ? a : b;
+}
+
+template <typename T>
+__device__ T min(T a, T b) {
+  return a < b ? a : b;
+}
+
+/*
+ * Check whether (x, y) is inside the boundary of the RoI.
+ */
+template <typename T>
+__device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) {
+  for (int i = 0; i < 4; i++) {
+    T start_w = roi_x[i];
+    T start_h = roi_y[i];
+    T end_w = roi_x[(i + 1) % 4];
+    T end_h = roi_y[(i + 1) % 4];
+    if (fabs(start_h - end_h) < 1e-4) {
+      if (fabs(y - start_h) < 1e-4 && fabs(y - end_h) < 1e-4 &&
+          GT_E<T>(x, min<T>(start_w, end_w)) &&
+          LT_E<T>(x, max<T>(start_w, end_w))) {
+        return true;
+      }
+    } else {
+      T intersec_x =
+          (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w;
+      if (fabs(intersec_x - x) < 1e-4 && GT_E(y, min<T>(start_h, end_h)) &&
+          LT_E<T>(y, max<T>(start_h, end_h))) {
+        return true;
+      }
+    }
+  }
+
+  int n_cross = 0;
+  for (int i = 0; i < 4; i++) {
+    T start_w = roi_x[i];
+    T start_h = roi_y[i];
+    T end_w = roi_x[(i + 1) % 4];
+    T end_h = roi_y[(i + 1) % 4];
+    if (fabs(start_h - end_h) < 1e-4) {
+      continue;
+    }
+    if (LT_E<T>(y, min<T>(start_h, end_h)) ||
+        GT<T>(y, max<T>(start_h, end_h))) {
+      continue;
+    }
+    T intersec_x =
+        (y - start_h) * (end_w - start_w) / (end_h - start_h) + start_w;
+    if (fabs(intersec_x - x) < 1e-4) {
+      return true;
+    }
+    if (GT<T>(intersec_x, x)) {
+      n_cross++;
+    }
+  }
+  return (n_cross % 2 == 1);
+}
+
+/**
+ * Perform bilinear interpolation in the input feature map.
+ */
+template <typename T>
+__device__ void bilinear_interpolate(const T* in_data, const int channels,
+                                     const int width, const int height,
+                                     int in_n, int in_c, T in_w, T in_h,
+                                     T* val) {
+  // Deal with cases where the source coords fall outside the feature
+  // map boundary.
+  if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) ||
+      GT<T>(in_h, height - 0.5)) {
+    val[0] = 0.0;
+    return;
+  }
+
+  if (GT<T>(0, in_w)) {
+    in_w = 0;
+  }
+  if (GT<T>(0, in_h)) {
+    in_h = 0;
+  }
+
+  int in_w_floor = floor(in_w);
+  int in_h_floor = floor(in_h);
+  int in_w_ceil;
+  int in_h_ceil;
+
+  if (GT_E<T>(in_w_floor, width - 1)) {
+    in_w_ceil = in_w_floor = width - 1;
+    in_w = static_cast<T>(in_w_floor);
+  } else {
+    in_w_ceil = in_w_floor + 1;
+  }
+
+  if (GT_E<T>(in_h_floor, height - 1)) {
+    in_h_ceil = in_h_floor = height - 1;
+    in_h = static_cast<T>(in_h_floor);
+  } else {
+    in_h_ceil = in_h_floor + 1;
+  }
+
+  T w_floor = in_w - in_w_floor;
+  T h_floor = in_h - in_h_floor;
+  T w_ceil = 1 - w_floor;
+  T h_ceil = 1 - h_floor;
+  const T* data = in_data + (in_n * channels + in_c) * height * width;
+  // Do bilinear interpolation
+  T v1 = data[in_h_floor * width + in_w_floor];
+  T v2 = data[in_h_ceil * width + in_w_floor];
+  T v3 = data[in_h_ceil * width + in_w_ceil];
+  T v4 = data[in_h_floor * width + in_w_ceil];
+  T w1 = w_ceil * h_ceil;
+  T w2 = w_ceil * h_floor;
+  T w3 = w_floor * h_floor;
+  T w4 = w_floor * h_ceil;
+  val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+}
+
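
The sampling above is ordinary clamped bilinear interpolation; the only wrinkle is the half-pixel tolerance that zeroes samples falling outside ``[-0.5, dim - 0.5]``. A NumPy sketch of the core sample (illustrative only, not the kernel):

.. code-block:: python

    import numpy as np

    def bilinear(data, x, y):
        # Sample a 2-D map at real-valued (x, y) with the same clamped
        # floor/ceil neighbourhood as the kernel above.
        h, w = data.shape
        x, y = min(max(x, 0.0), w - 1), min(max(y, 0.0), h - 1)
        x0, y0 = int(np.floor(x)), int(np.floor(y))
        x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
        lx, ly = x - x0, y - y0
        return ((1 - lx) * (1 - ly) * data[y0, x0] +
                (1 - lx) * ly * data[y1, x0] +
                lx * ly * data[y1, x1] +
                lx * (1 - ly) * data[y0, x1])

    m = np.array([[0., 1.], [2., 3.]])
    assert np.isclose(bilinear(m, 0.5, 0.5), 1.5)
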
+/**
+ * Get the source coordinates in the input feature map.
+ *
+ * (u, v, w)^T = matrix * (out_w, out_h, 1)^T
+ *
+ * in_w = u / w
+ * in_h = v / w
+ *
+ */
+template <typename T>
+__device__ void get_source_coords(T matrix[], int out_w, int out_h, T* in_w,
+                                  T* in_h) {
+  T u = matrix[0] * out_w + matrix[1] * out_h + matrix[2];
+  T v = matrix[3] * out_w + matrix[4] * out_h + matrix[5];
+  T w = matrix[6] * out_w + matrix[7] * out_h + matrix[8];
+
+  in_w[0] = u / w;
+  in_h[0] = v / w;
+}
+
+/**
+ * Get the matrix of perspective transform.
+ *
+ * dx1 = x1 - x2
+ * dx2 = x3 - x2
+ * dx3 = x0 - x1 + x2 - x3
+ * dy1 = y1 - y2
+ * dy2 = y3 - y2
+ * dy3 = y0 - y1 + y2 - y3
+ *
+ * a11 = (x1 - x0 + a31 * (w - 1) * x1) / (w - 1)
+ * a12 = (x3 - x0 + a32 * (h - 1) * x3) / (h - 1)
+ * a13 = x0
+ * a21 = (y1 - y0 + a31 * (w - 1) * y1) / (w - 1)
+ * a22 = (y3 - y0 + a32 * (h - 1) * y3) / (h - 1)
+ * a23 = y0
+ * a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1)
+ * a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1)
+ * a33 = 1
+ *
+ */
+template <typename T>
+__device__ void get_transform_matrix(const int transformed_width,
+                                     const int transformed_height, T roi_x[],
+                                     T roi_y[], T matrix[]) {
+  T x0 = roi_x[0];
+  T x1 = roi_x[1];
+  T x2 = roi_x[2];
+  T x3 = roi_x[3];
+  T y0 = roi_y[0];
+  T y1 = roi_y[1];
+  T y2 = roi_y[2];
+  T y3 = roi_y[3];
+
+  // Estimate the height and width of RoI
+  T len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1));
+  T len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2));
+  T len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3));
+  T len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0));
+  T estimated_height = (len2 + len4) / 2.0;
+  T estimated_width = (len1 + len3) / 2.0;
+
+  // Get the normalized height and normalized width
+  int normalized_height = transformed_height;
+  int normalized_width =
+      round(estimated_width * (normalized_height - 1) / estimated_height) + 1;
+  normalized_width = min(normalized_width, transformed_width);
+
+  T dx1 = x1 - x2;
+  T dx2 = x3 - x2;
+  T dx3 = x0 - x1 + x2 - x3;
+  T dy1 = y1 - y2;
+  T dy2 = y3 - y2;
+  T dy3 = y0 - y1 + y2 - y3;
+
+  matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) /
+              (normalized_width - 1);
+  matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) /
+              (normalized_height - 1);
+  matrix[8] = 1;
+
+  matrix[3] = (y1 - y0 + matrix[6] * (normalized_width - 1) * y1) /
+              (normalized_width - 1);
+  matrix[4] = (y3 - y0 + matrix[7] * (normalized_height - 1) * y3) /
+              (normalized_height - 1);
+  matrix[5] = y0;
+
+  matrix[0] = (x1 - x0 + matrix[6] * (normalized_width - 1) * x1) /
+              (normalized_width - 1);
+  matrix[1] = (x3 - x0 + matrix[7] * (normalized_height - 1) * x3) /
+              (normalized_height - 1);
+  matrix[2] = x0;
+}
+
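
The closed form above is the standard projective map from the output rectangle onto the ROI quadrilateral; the kernel additionally shrinks the working width (``normalized_width``) so the ROI's estimated aspect ratio is preserved. A NumPy sketch (illustrative only) that mirrors the formulas and checks that the four output corners land exactly on the ROI corners:

.. code-block:: python

    import numpy as np

    def transform_matrix(w, h, roi_x, roi_y):
        # Projective map sending output corners (0, 0), (w-1, 0),
        # (w-1, h-1), (0, h-1) onto the four ROI corners, mirroring
        # the closed-form solution in the kernel above.
        x0, x1, x2, x3 = roi_x
        y0, y1, y2, y3 = roi_y
        dx1, dx2, dx3 = x1 - x2, x3 - x2, x0 - x1 + x2 - x3
        dy1, dy2, dy3 = y1 - y2, y3 - y2, y0 - y1 + y2 - y3
        a31 = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (w - 1)
        a32 = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (h - 1)
        return np.array(
            [[(x1 - x0 + a31 * (w - 1) * x1) / (w - 1),
              (x3 - x0 + a32 * (h - 1) * x3) / (h - 1), x0],
             [(y1 - y0 + a31 * (w - 1) * y1) / (w - 1),
              (y3 - y0 + a32 * (h - 1) * y3) / (h - 1), y0],
             [a31, a32, 1.0]])

    M = transform_matrix(7, 7, [1., 9., 8., 2.], [2., 3., 9., 8.])
    for (ow, oh), (ex, ey) in zip([(0, 0), (6, 0), (6, 6), (0, 6)],
                                  [(1, 2), (9, 3), (8, 9), (2, 8)]):
        u, v, w = M @ np.array([ow, oh, 1.0])
        assert np.allclose([u / w, v / w], [ex, ey])
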
+template <typename T>
+__global__ void RoiTransformKernel(const float* input_data,
+                                   const float* rois_data,
+                                   const int* roi2image_data, int num_rois,
+                                   int in_height, int in_width, int channels,
+                                   int transformed_height,
+                                   int transformed_width, float spatial_scale,
+                                   T* output_data) {
+  int output_size =
+      num_rois * transformed_height * transformed_width * channels;
+
+  CUDA_1D_KERNEL_LOOP(index, output_size) {
+    // (n, c, out_h, out_w) is an element in the transformed output
+    int out_w = idx4_4(index, num_rois, channels, transformed_height,
+                       transformed_width);
+    int out_h = idx4_3(index, num_rois, channels, transformed_height,
+                       transformed_width);
+    int c = idx4_2(index, num_rois, channels, transformed_height,
+                   transformed_width);
+    int n = idx4_1(index, num_rois, channels, transformed_height,
+                   transformed_width);
+
+    auto bottom_rois = rois_data + n * 8;
+    T roi_x[4];
+    T roi_y[4];
+    for (int k = 0; k < 4; ++k) {
+      roi_x[k] = bottom_rois[2 * k] * spatial_scale;
+      roi_y[k] = bottom_rois[2 * k + 1] * spatial_scale;
+    }
+
+    // Get transform matrix
+    T matrix[9];
+    get_transform_matrix<T>(transformed_width, transformed_height, roi_x, roi_y,
+                            matrix);
+
+    // Get source coords
+    T in_w;
+    T in_h;
+    get_source_coords<T>(matrix, out_w, out_h, &in_w, &in_h);
+
+    if (in_quad<T>(in_w, in_h, roi_x, roi_y)) {
+      if (GT<T>(-0.5, in_w) || GT<T>(in_w, static_cast<T>(in_width - 0.5)) ||
+          GT<T>(-0.5, in_h) || GT<T>(in_h, static_cast<T>(in_height - 0.5))) {
+        // Source coords fall outside the input image: output zero.
+        output_data[index] = 0.0;
+      } else {
+        // Perform bilinear interpolation
+        int in_n = roi2image_data[n];
+        bilinear_interpolate<T>(input_data, channels, in_width, in_height, in_n,
+                                c, in_w, in_h, output_data + index);
+      }
+
+    } else {
+      // Source coords fall outside the quad: output zero.
+      output_data[index] = 0.0;
+    }
+  }
+}
+
+template <typename T>
+class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+
+    auto transformed_height = ctx.Attr<int>("transformed_height");
+    auto transformed_width = ctx.Attr<int>("transformed_width");
+    auto spatial_scale = ctx.Attr<float>("spatial_scale");
+
+    auto in_dims = in->dims();
+    int batch_size = in_dims[0];
+    int channels = in_dims[1];
+    int in_height = in_dims[2];
+    int in_width = in_dims[3];
+    int rois_num = rois->dims()[0];
+
+    const T* input_data = in->data<T>();
+    T* output_data = out->mutable_data<T>(ctx.GetPlace());
+    const T* rois_data = rois->data<T>();
+
+    framework::Tensor roi2image;
+    framework::Tensor roi2image_dev;
+    roi2image.Resize({rois_num});
+    int* roi2image_data = roi2image.mutable_data<int>(platform::CPUPlace());
+    auto lod = rois->lod().back();
+    for (size_t i = 0; i < lod.size() - 1; ++i) {
+      for (size_t j = lod[i]; j < lod[i + 1]; ++j) {
+        roi2image_data[j] = i;
+      }
+    }
+    TensorCopySync(roi2image, ctx.GetPlace(), &roi2image_dev);
+
+    int out_size = rois_num * transformed_height * transformed_width * channels;
+    auto stream = ctx.cuda_device_context().stream();
+    int block = 512;
+    int grid = (out_size + block - 1) / block;
+
+    RoiTransformKernel<T><<<grid, block, 0, stream>>>(
+        input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height,
+        in_width, channels, transformed_height, transformed_width,
+        spatial_scale, output_data);
+  }
+};
+
+template <typename T>
+__device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width,
+                                  const int height) {
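+  // Returns the bilinear weight that input pixel (w, h) contributes to a
+  // sample taken at (xs, ys); this is the adjoint of bilinear_interpolate
+  // and is used below to scatter output gradients back onto the input.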
+  if (GT<T>(-0.5, xs) || GT<T>(xs, width - 0.5) || GT<T>(-0.5, ys) ||
+      GT<T>(ys, height - 0.5)) {
+    return 0;
+  }
+
+  if (GT<T>(0, xs)) {
+    xs = 0;
+  }
+  if (GT<T>(0, ys)) {
+    ys = 0;
+  }
+
+  int xs_floor = floor(xs);
+  int ys_floor = floor(ys);
+  int xs_ceil;
+  int ys_ceil;
+
+  if (GT_E<T>(xs_floor, width - 1)) {
+    xs_ceil = xs_floor = width - 1;
+    xs = static_cast<T>(xs_floor);
+  } else {
+    xs_ceil = xs_floor + 1;
+  }
+
+  if (GT_E(ys_floor, height - 1)) {
+    ys_ceil = ys_floor = height - 1;
+    ys = static_cast<T>(ys_floor);
+  } else {
+    ys_ceil = ys_floor + 1;
+  }
+
+  T weight = 0;
+  if (w == xs_floor) {
+    if (h == ys_floor) {
+      weight = (w + 1 - xs) * (h + 1 - ys);
+    } else if (h == ys_ceil) {
+      weight = (w + 1 - xs) * (ys + 1 - h);
+    }
+  } else if (w == xs_ceil) {
+    if (h == ys_floor) {
+      weight = (xs + 1 - w) * (h + 1 - ys);
+    } else if (h == ys_ceil) {
+      weight = (xs + 1 - w) * (ys + 1 - h);
+    }
+  }
+  return weight;
+}
+
+template <typename T>
+__global__ void RoiTransformGradKernel(
+    const size_t* lod, const T* rois_data, int batch_size, int num_rois,
+    int in_height, int in_width, int channels, int transformed_height,
+    int transformed_width, float spatial_scale, const T* out_grad_data,
+    T* in_grad_data) {
+  int input_size = batch_size * in_height * in_width * channels;
+
+  CUDA_1D_KERNEL_LOOP(index, input_size) {
+    // (n, c, h, w) coords in input
+    int in_w = idx4_4(index, batch_size, channels, in_height, in_width);
+    int in_h = idx4_3(index, batch_size, channels, in_height, in_width);
+    int c = idx4_2(index, batch_size, channels, in_height, in_width);
+    int n = idx4_1(index, batch_size, channels, in_height, in_width);
+
+    T gradient = 0.0;
+    // Accumulate gradient over all RoIs that interpolated this element
+    for (size_t roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) {
+      const T* rois = rois_data + roi_idx * 8;
+      T roi_x[4];
+      T roi_y[4];
+      for (int k = 0; k < 4; ++k) {
+        roi_x[k] = rois[2 * k] * spatial_scale;
+        roi_y[k] = rois[2 * k + 1] * spatial_scale;
+      }
+
+      // Get transform matrix
+      T matrix[9];
+      get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
+                              roi_y, matrix);
+
+      const T* out_grad_ptr =
+          out_grad_data +
+          (roi_idx * channels + c) * transformed_height * transformed_width;
+      for (int out_h = 0; out_h < transformed_height; ++out_h) {
+        for (int out_w = 0; out_w < transformed_width; ++out_w) {
+          T src_w;
+          T src_h;
+          get_source_coords<T>(matrix, out_w, out_h, &src_w, &src_h);
+          if (in_quad<T>(src_w, src_h, roi_x, roi_y)) {
+            if (GT<T>(-0.5, src_w) ||
+                GT<T>(src_w, static_cast<T>(in_width - 0.5)) ||
+                GT<T>(-0.5, src_h) ||
+                GT<T>(src_h, static_cast<T>(in_height - 0.5))) {
+              continue;
+            }
+            T weight = get_feature_gradient<T>(src_w, src_h, in_w, in_h,
+                                               in_width, in_height);
+            gradient +=
+                out_grad_ptr[out_h * transformed_width + out_w] * weight;
+          }
+        }
+      }
+    }
+    in_grad_data[index] = gradient;
+  }
+}
+
+template <typename T>
+class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
+    auto* out_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto transformed_height = ctx.Attr<int>("transformed_height");
+    auto transformed_width = ctx.Attr<int>("transformed_width");
+    auto spatial_scale = ctx.Attr<float>("spatial_scale");
+
+    auto in_dims = in->dims();
+    int batch_size = in_dims[0];
+    int channels = in_dims[1];
+    int in_height = in_dims[2];
+    int in_width = in_dims[3];
+    int rois_num = rois->dims()[0];
+
+    T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
+    const T* out_grad_data = out_grad->data<T>();
+    const T* rois_data = rois->data<T>();
+
+    auto lod = rois->lod().back();
+    auto lod_data = lod.CUDAData(ctx.GetPlace());
+
+    int in_size = in->numel();
+    auto stream = ctx.cuda_device_context().stream();
+    int block = 512;
+    int grid = (in_size + block - 1) / block;
+
+    RoiTransformGradKernel<T><<<grid, block, 0, stream>>>(
+        lod_data, rois_data, batch_size, rois_num, in_height, in_width,
+        channels, transformed_height, transformed_width, spatial_scale,
+        out_grad_data, in_grad_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(roi_perspective_transform,
+                        ops::CUDAROIPerspectiveTransformOpKernel<float>);
+REGISTER_OP_CUDA_KERNEL(roi_perspective_transform_grad,
+                        ops::CUDAROIPerspectiveTransformGradOpKernel<float>);
diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h
index dd5d138a1e..dd1ab85fd8 100644
--- a/paddle/fluid/operators/detection_map_op.h
+++ b/paddle/fluid/operators/detection_map_op.h
@@ -76,8 +76,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
     auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type"));
     int class_num = ctx.Attr<int>("class_num");
 
-    auto& label_lod = in_label->lod();
-    auto& detect_lod = in_detect->lod();
+    auto label_lod = in_label->lod();
+    auto detect_lod = in_detect->lod();
     PADDLE_ENFORCE_EQ(label_lod.size(), 1UL,
                       "Only support one level sequence now.");
     PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(),
@@ -166,11 +166,11 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
     auto labels = framework::EigenTensor<T, 2>::From(input_label);
     auto detect = framework::EigenTensor<T, 2>::From(input_detect);
 
-    auto& label_lod = input_label.lod();
-    auto& detect_lod = input_detect.lod();
+    auto label_lod = input_label.lod();
+    auto detect_lod = input_detect.lod();
 
     int batch_size = label_lod[0].size() - 1;
-    auto& label_index = label_lod[0];
+    auto label_index = label_lod[0];
 
     for (int n = 0; n < batch_size; ++n) {
       std::map<int, std::vector<Box>> boxes;
@@ -274,6 +274,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
 
     output_true_pos->set_lod(true_pos_lod);
     output_false_pos->set_lod(false_pos_lod);
+    return;
   }
 
   void GetInputPos(const framework::Tensor& input_pos_count,
@@ -291,7 +292,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
     auto SetData = [](const framework::LoDTensor& pos_tensor,
                       std::map<int, std::vector<std::pair<T, int>>>& pos) {
       const T* pos_data = pos_tensor.data<T>();
-      auto& pos_data_lod = pos_tensor.lod()[0];
+      auto pos_data_lod = pos_tensor.lod()[0];
       for (size_t i = 0; i < pos_data_lod.size() - 1; ++i) {
         for (size_t j = pos_data_lod[i]; j < pos_data_lod[i + 1]; ++j) {
           T score = pos_data[j * 2];
@@ -316,23 +317,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
       std::map<int, std::vector<std::pair<T, int>>>* false_pos) const {
     int batch_size = gt_boxes.size();
     for (int n = 0; n < batch_size; ++n) {
-      auto& image_gt_boxes = gt_boxes[n];
-      for (auto& image_gt_box : image_gt_boxes) {
+      auto image_gt_boxes = gt_boxes[n];
+      for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) {
         size_t count = 0;
-        auto& labeled_bboxes = image_gt_box.second;
+        auto labeled_bboxes = it->second;
         if (evaluate_difficult) {
           count = labeled_bboxes.size();
         } else {
-          for (auto& box : labeled_bboxes) {
-            if (!box.is_difficult) {
-              ++count;
-            }
-          }
+          for (size_t i = 0; i < labeled_bboxes.size(); ++i)
+            if (!(labeled_bboxes[i].is_difficult)) ++count;
         }
         if (count == 0) {
           continue;
         }
-        int label = image_gt_box.first;
+        int label = it->first;
         if (label_pos_count->find(label) == label_pos_count->end()) {
           (*label_pos_count)[label] = count;
         } else {
diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc
index 1617cc1b95..c4854d50b6 100644
--- a/paddle/fluid/operators/distributed/variable_response.cc
+++ b/paddle/fluid/operators/distributed/variable_response.cc
@@ -92,9 +92,14 @@ bool VariableResponse::CopyLodTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
+  auto server_var = GetVar();
+  if (!server_var) {
+    LOG(ERROR) << "received var is not found on current server: "
+               << meta_.varname();
+    return false;
+  }
-  auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
+  auto* tensor = server_var->GetMutable<framework::LoDTensor>();
   tensor->Resize(dims);
-
   framework::LoD lod;
   for (int i = 0; i < meta_.lod_level(); ++i) {
     framework::Vector<size_t> v;
@@ -107,7 +112,6 @@ bool VariableResponse::CopyLodTensorData(
 
   void* tensor_data =
       tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
-
   if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
     return false;
   }
diff --git a/paddle/fluid/operators/extract_rows_op.cc b/paddle/fluid/operators/extract_rows_op.cc
index 3acae3bcdf..9a297d03cf 100644
--- a/paddle/fluid/operators/extract_rows_op.cc
+++ b/paddle/fluid/operators/extract_rows_op.cc
@@ -50,7 +50,7 @@ class ExtractRowsOp : public framework::OperatorBase {
     auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
     auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
 
-    auto &in_rows = in.rows();
+    auto in_rows = in.rows();
     auto out_dim = framework::make_ddim(
         std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
     auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h
index b6f4ab9377..47c771f7c5 100644
--- a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h
+++ b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h
@@ -85,26 +85,59 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                      T *prev_output_value, int frame_size,
                                      ActivationType active_gate) {
 #ifdef __AVX__
-  __m256 r_value_update_gate;
-  __m256 r_value_reset_gate;
+  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
+  __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_reset_output;
-  __m256 r_prev_out = _mm256_set1_ps(0.0f);
-  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
-  __m256 *reset_gate = reinterpret_cast<__m256 *>(gate_value + frame_size);
+  __m256 r_prev_out = _mm256_set1_ps(0.0f),
+         r_prev_out_last = _mm256_set1_ps(0.0f);
+  T *update_gate = gate_value;
+  T *reset_gate = gate_value + frame_size;
+  int block = 8;
+  const int n = frame_size;
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+
+  if (rest > 0) {
+    i = n - block;
+    r_value_update_gate_last =
+        _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_reset_gate_last = _mm256_loadu_ps((const float *)(reset_gate + i));
+    if (prev_output_value) {
+      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
+    }
+  }
 
-  for (int i = 0; i < frame_size / 8; i++) {
-    r_value_update_gate = update_gate[i];
-    r_value_reset_gate = reset_gate[i];
+  for (i = 0; i < end; i += block) {
+    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_reset_gate = _mm256_loadu_ps((const float *)(reset_gate + i));
     if (prev_output_value) {
-      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
+      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
 
     op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
                     &r_value_reset_output, active_gate);
 
-    update_gate[i] = r_value_update_gate;
-    reset_gate[i] = r_value_reset_gate;
-    (reinterpret_cast<__m256 *>(reset_output_value))[i] = r_value_reset_output;
+    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
+                     r_value_update_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
+                     r_value_reset_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
+                     r_value_reset_output);
+  }
+
+  if (rest > 0) {
+    i = n - block;
+
+    op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last,
+                    &r_prev_out_last, &r_value_reset_output, active_gate);
+
+    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
+                     r_value_update_gate_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
+                     r_value_reset_gate_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
+                     r_value_reset_output);
   }
 #endif
 }
@@ -115,26 +148,55 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *output_value, int frame_size,
                                      ActivationType active_node) {
 #ifdef __AVX__
-  __m256 r_value_update_gate;
-  __m256 r_value_frame_state;
-  __m256 r_prev_out = _mm256_set1_ps(0.0f);
+  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
+  __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
+  __m256 r_prev_out = _mm256_set1_ps(0.0f),
+         r_prev_out_last = _mm256_set1_ps(0.0f);
   __m256 r_output;
-  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
-  __m256 *frame_state = reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
+  T *update_gate = gate_value;
+  T *frame_state = gate_value + frame_size * 2;
+  int block = 8;
+  const int n = frame_size;
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+
+  if (rest > 0) {
+    i = n - block;
+    r_value_update_gate_last =
+        _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_frame_state_last =
+        _mm256_loadu_ps((const float *)(frame_state + i));
+    if (prev_output_value) {
+      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
+    }
+  }
 
-  for (int i = 0; i < frame_size / 8; i++) {
-    r_value_update_gate = update_gate[i];
-    r_value_frame_state = frame_state[i];
+  for (i = 0; i < end; i += block) {
+    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_frame_state = _mm256_loadu_ps((const float *)(frame_state + i));
     if (prev_output_value) {
-      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
+      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
 
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
                     &r_output, active_node);
 
-    frame_state[i] = r_value_frame_state;
-    (reinterpret_cast<__m256 *>(output_value))[i] = r_output;
+    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
+                     r_value_frame_state);
+    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
+  }
+
+  if (rest > 0) {
+    i = n - block;
+    op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
+                    &r_prev_out_last, &r_output, active_node);
+
+    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
+                     r_value_frame_state_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
   }
+
 #endif
 }
 
@@ -143,7 +205,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
                                  GRUMetaValue<T> value, int frame_size,
                                  int batch_size, ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+    if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+        (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
           op_reset_output, value.gate_value, value.reset_output_value,
           value.prev_out_value, frame_size, active_gate);
@@ -166,7 +229,8 @@ inline void forward_final_output(OpFinalOutput op_final_output,
                                  GRUMetaValue<T> value, int frame_size,
                                  int batch_size, ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+    if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+        (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                       value.prev_out_value, value.output_value,
                                       frame_size, active_node);
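
The AVX paths above now also handle frame sizes that are not a multiple of 8: a full 8-float tail block is loaded before the main loop and processed after it, so the lanes it shares with the last full block are still computed from their original inputs. That is also why the dispatch condition was relaxed from ``!(frame_size & (8 - 1))`` to ``frame_size > 8 - 1``. A NumPy sketch of the tail-handling idea (illustrative only; any element-wise ``f`` stands in for the gate math):

.. code-block:: python

    import numpy as np

    def blocked_elementwise(x, f, block=8):
        # Requires len(x) >= block, matching the relaxed dispatch check.
        n = len(x)
        rest = n % block
        end = n - rest
        # Snapshot the last full-width block before the main loop so the
        # overlapping lanes still see their original inputs.
        last = x[n - block:].copy() if rest > 0 else None
        for i in range(0, end, block):
            x[i:i + block] = f(x[i:i + block])
        if rest > 0:
            x[n - block:] = f(last)
        return x

    x = np.arange(13, dtype=np.float32)
    assert np.allclose(blocked_elementwise(x.copy(), np.tanh), np.tanh(x))
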
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index a830dc5250..8e8baf49b2 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
   framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
+    (*this)(context, input, &out);
+    return out;
+  }
+
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output) {
+    framework::SelectedRows& out = *output;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
         out_data[out_i * input_width + j] += input_data[i * input_width + j];
       }
     }
-    return out;
   }
 };
 
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu
index d559aaa721..b27880c232 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor.cu
@@ -60,9 +60,11 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
     auto out_place = context.GetPlace();
     PADDLE_ENFORCE(platform::is_gpu_place(out_place));
 
-    memory::Copy(boost::get<platform::CUDAPlace>(out_place), out_data,
-                 boost::get<platform::CUDAPlace>(in1_place), in1_data,
-                 in1_value.numel() * sizeof(T), context.stream());
+    memory::Copy(
+        boost::get<platform::CUDAPlace>(out_place), out_data,
+        boost::get<platform::CUDAPlace>(in1_place), in1_data,
+        in1_value.numel() * sizeof(T),
+        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
 
     auto* in2_data = in2_value.data<T>();
     memory::Copy(boost::get<platform::CUDAPlace>(out_place),
@@ -107,7 +109,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
 
     auto& in1_value = input1.value();
-    framework::Vector<int64_t> in1_rows(input1.rows());
+    auto& in1_rows = input1.rows();
 
     int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
     PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
@@ -146,7 +148,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(in1_height, input2->height());
 
-    auto& in1_rows = input1.rows();
+    framework::Vector<int64_t> in1_rows(input1.rows());
     auto& in2_rows = *(input2->mutable_rows());
 
     auto& in1_value = input1.value();
@@ -206,7 +208,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
     PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
 
     auto& in1_value = input1.value();
-    framework::Vector<int64_t> in1_rows(input1.rows());
+    auto& in1_rows = input1.rows();
 
     int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
     PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
@@ -234,7 +236,7 @@ template <typename T, int block_size>
 __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                                T* out, const int64_t* out_rows,
                                size_t out_rows_size, int64_t row_numel) {
-  const int ty = blockIdx.y;
+  const int ty = blockIdx.x;
   int tid = threadIdx.x;
   __shared__ size_t out_idx;
 
@@ -260,6 +262,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
+    (*this)(context, input, &out);
+    return out;
+  }
+
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output) {
+    framework::SelectedRows& out = *output;
     framework::Vector<int64_t> input_rows(input.rows());
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
@@ -281,16 +291,12 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
 
     const int block_size = 256;
     dim3 threads(block_size, 1);
-    dim3 grid1(1, input_rows.size());
+    dim3 grid1(input_rows.size(), 1);
 
-    MergeAddKernel<
-        T, 256><<<grid1, threads, 0,
-                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(
+    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
         input_data, input_rows.CUDAData(context.GetPlace()), out_data,
         out.mutable_rows()->CUDAMutableData(context.GetPlace()),
         out.rows().size(), input_width);
-    return out;
   }
 };
 
diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h
index 18304f83f8..aa419f74fc 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.h
+++ b/paddle/fluid/operators/math/selected_rows_functor.h
@@ -65,6 +65,9 @@ struct MergeAdd {
   // the input SelectedRows object.
   framework::SelectedRows operator()(const DeviceContext& context,
                                      const framework::SelectedRows& input);
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output);
 };
 
 template <typename DeviceContext, typename T>
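
Both overloads of ``MergeAdd`` share the same semantics: duplicate rows in the input ``SelectedRows`` are collapsed into a sorted, unique row list and their values summed; the new output-parameter form just writes into a caller-provided destination instead of returning a fresh object. A NumPy sketch of those semantics (illustrative only, not the C++ API):

.. code-block:: python

    import numpy as np

    def merge_add(rows, values):
        # rows may repeat; keep each unique row once, summing its values.
        merged_rows = sorted(set(rows))
        idx = {r: i for i, r in enumerate(merged_rows)}
        out = np.zeros((len(merged_rows), values.shape[1]), values.dtype)
        for r, v in zip(rows, values):
            out[idx[r]] += v
        return merged_rows, out

    rows = [4, 0, 4, 7]
    vals = np.ones((4, 3), dtype=np.float32)
    m_rows, m_vals = merge_add(rows, vals)
    assert m_rows == [0, 4, 7] and np.allclose(m_vals[1], 2.0)
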
diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu
index e89b27855b..5fc50aba25 100644
--- a/paddle/fluid/operators/math/selected_rows_functor_test.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu
@@ -20,7 +20,9 @@ limitations under the License. */
 TEST(selected_rows_functor, gpu_add) {
   paddle::platform::CUDAPlace gpu_place(0);
   paddle::platform::CPUPlace cpu_place;
-  paddle::platform::CUDADeviceContext ctx(gpu_place);
+  paddle::platform::CUDADeviceContext& ctx =
+      *reinterpret_cast<paddle::platform::CUDADeviceContext*>(
+          paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
   paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
                                        float>
       functor;
@@ -132,7 +134,9 @@ TEST(selected_rows_functor, gpu_add) {
 TEST(selected_rows_functor, gpu_add_to) {
   paddle::platform::CUDAPlace gpu_place(0);
   paddle::platform::CPUPlace cpu_place;
-  paddle::platform::CUDADeviceContext ctx(gpu_place);
+  paddle::platform::CUDADeviceContext& ctx =
+      *reinterpret_cast<paddle::platform::CUDADeviceContext*>(
+          paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
   paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
                                        float>
       functor;
diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h
index 2c4c241125..6dffe527c1 100644
--- a/paddle/fluid/operators/sum_op.h
+++ b/paddle/fluid/operators/sum_op.h
@@ -123,6 +123,7 @@ class SumKernel : public framework::OpKernel<T> {
 
       out_value->Resize(framework::make_ddim(in_dim));
       out_value->mutable_data<T>(context.GetPlace());
+
       // if all the input sparse vars are empty, no need to
       // merge these vars.
       if (first_dim == 0UL) {
diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc
index f577068d1f..1f61a0e289 100644
--- a/paddle/fluid/pybind/const_value.cc
+++ b/paddle/fluid/pybind/const_value.cc
@@ -36,7 +36,9 @@ void BindConstValue(pybind11::module* m) {
       .value("Backward", framework::OpRole::kBackward)
       .value("Optimize", framework::OpRole::kOptimize)
       .value("Loss", framework::OpRole::kLoss)
-      .value("RPC", framework::OpRole::kRPC);
+      .value("RPC", framework::OpRole::kRPC)
+      .value("Dist", framework::OpRole::kDist)
+      .value("LRSched", framework::OpRole::kLRSched);
 
   op_proto_and_checker_maker.def(
       "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 1ca2ac2ddc..9e4a5ae8ba 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -46,7 +46,7 @@ from . import transpiler
 from .param_attr import ParamAttr, WeightNormParamAttr
 from .data_feeder import DataFeeder
 from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
-from .transpiler import DistributeTranspiler, InferenceTranspiler, \
+from .transpiler import DistributeTranspiler, \
     memory_optimize, release_memory, DistributeTranspilerConfig
 from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
 from . import clip
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 0abbb68151..d7e5e47048 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1509,6 +1509,30 @@ class Program(object):
         self._op_role_var = []
         self._current_role = OpRole.Forward
 
+    @contextlib.contextmanager
+    def _lr_schedule_guard(self):
+        """
+        A ``with`` guard that sets the :code:`LRSched` :code:`OpRole` and
+        :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is set to
+        the target learning rate.
+
+        Notes: This is a very low level API. Users should not use it directly.
+
+
+        Examples:
+
+            >>> p, g = backward(...)
+            >>> with program._lr_schedule_guard():
+            >>>     lr = lr * decay
+        """
+        OpRole = core.op_proto_and_checker_maker.OpRole
+        self._current_role = OpRole.LRSched
+        # TODO(typhoonzero): how to set target learning rate var
+        self._op_role_var = []
+        yield
+        self._op_role_var = []
+        self._current_role = OpRole.Forward
+
     def __str__(self):
         """
         Get the protobuf debug string of this Program.
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 7a7a0078a5..a26b8df5a2 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -74,7 +74,7 @@ class Initializer(object):
     directly, but need to use one of its implementations.
     """
 
-    def __init_(self):
+    def __init__(self):
         pass
 
     def __call__(self, param, block):
@@ -293,7 +293,7 @@ class TruncatedNormalInitializer(Initializer):
         assert loc is not None
         assert scale is not None
         assert seed is not None
-        super(NormalInitializer, self).__init__()
+        super(TruncatedNormalInitializer, self).__init__()
         self._mean = loc
         self._std_dev = scale
         self._seed = seed
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 78bb8a1a0a..e703e5ac79 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -27,8 +27,7 @@ from . import core
 
 __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
-    'load_persistables', 'save_inference_model', 'load_inference_model',
-    'get_inference_program'
+    'load_persistables', 'save_inference_model', 'load_inference_model'
 ]
 
 
@@ -504,23 +503,6 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
         filename=filename)
 
 
-def get_inference_program(target_vars, main_program=None):
-    if main_program is None:
-        main_program = default_main_program()
-    if not isinstance(target_vars, list):
-        target_vars = [target_vars]
-    vars = []
-    for var in target_vars:
-        if isinstance(var, Evaluator):
-            vars.extend(var.states)
-            vars.extend(var.metrics)
-        else:
-            vars.append(var)
-    pruned_program = main_program._prune(targets=vars)
-    inference_program = pruned_program._inference_optimize()
-    return inference_program
-
-
 def prepend_feed_ops(inference_program,
                      feed_target_names,
                      feed_holder_name='feed'):
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 8e86bec860..574d0d727c 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -39,6 +39,7 @@ __all__ = [
     'detection_map',
     'rpn_target_assign',
     'anchor_generator',
+    'roi_perspective_transform',
     'generate_proposal_labels',
     'generate_proposals',
 ]
@@ -1262,6 +1263,54 @@ def anchor_generator(input,
     return anchor, var
 
 
+def roi_perspective_transform(input,
+                              rois,
+                              transformed_height,
+                              transformed_width,
+                              spatial_scale=1.0):
+    """
+    ROI perspective transform op.
+
+    Args:
+        input (Variable): The input of ROIPerspectiveTransformOp. The format
+                          of the input tensor is NCHW, where N is the batch
+                          size, C is the number of input channels, H is the
+                          height of the feature, and W is the width of the
+                          feature.
+        rois (Variable):  ROIs (Regions of Interest) to be transformed. It
+                          should be a 2-D LoDTensor of shape (num_rois, 8),
+                          given as [[x1, y1, x2, y2, x3, y3, x4, y4], ...],
+                          where (x1, y1) is the top-left corner, (x2, y2) the
+                          top-right corner, (x3, y3) the bottom-right corner,
+                          and (x4, y4) the bottom-left corner of the ROI.
+        transformed_height (integer): The height of transformed output.
+        transformed_width (integer): The width of transformed output.
+        spatial_scale (float): Spatial scale factor to scale ROI coords. Default: 1.0
+
+    Returns:
+        Variable: The output of ROIPerspectiveTransformOp, which is a 4-D
+                  tensor with shape (num_rois, channels, transformed_h,
+                  transformed_w).
+
+    Examples:
+        .. code-block:: python
+
+            out = fluid.layers.roi_perspective_transform(input, rois, 7, 7, 1.0)
+    """
+    helper = LayerHelper('roi_perspective_transform', **locals())
+    dtype = helper.input_dtype()
+    out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type="roi_perspective_transform",
+        inputs={"X": input,
+                "ROIs": rois},
+        outputs={"Out": out},
+        attrs={
+            "transformed_height": transformed_height,
+            "transformed_width": transformed_width,
+            "spatial_scale": spatial_scale
+        })
+    return out
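
A slightly fuller usage sketch than the docstring example (hypothetical shapes; ``rois`` is declared with ``lod_level=1`` so each image in the batch can carry a different number of ROIs):

.. code-block:: python

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[256, 28, 28], dtype='float32')
    rois = fluid.layers.data(
        name='rois', shape=[8], lod_level=1, dtype='float32')
    # 7x7 output per ROI; ROI coords are given in the original image
    # space, so a 1/16-resolution feature map uses spatial_scale=0.0625.
    out = fluid.layers.roi_perspective_transform(x, rois, 7, 7, 0.0625)
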
+
+
 def generate_proposal_labels(rpn_rois,
                              gt_classes,
                              is_crowd,
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index be368007dd..2b947ca9e8 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -27,7 +27,7 @@ from . import nn
 from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
-from ..framework import default_main_program, Parameter
+from ..framework import default_main_program, Parameter, unique_name
 
 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
@@ -63,11 +63,12 @@ def noam_decay(d_model, warmup_steps):
     Returns:
         The decayed learning rate.
     """
-    global_step = _decay_step_counter(1)
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter(1)
 
-    a = global_step**-0.5
-    b = (warmup_steps**-1.5) * global_step
-    lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
+        a = global_step**-0.5
+        b = (warmup_steps**-1.5) * global_step
+        lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
 
     return lr_value
 
@@ -108,14 +109,15 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
           sgd_optimizer.minimize(avg_cost)
 
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
 
-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
-    decayed_lr = learning_rate * (decay_rate**div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
+        decayed_lr = learning_rate * (decay_rate**div_res)
 
-    return decayed_lr
+        return decayed_lr
 
 
 def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -136,14 +138,15 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     Returns:
         The decayed learning rate
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
 
-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
-    decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
+        decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
 
-    return decayed_lr
+        return decayed_lr
 
 
 def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -181,15 +184,16 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
                     staircase=True))
           sgd_optimizer.minimize(avg_cost)
     """
-    global_step = _decay_step_counter()
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
 
-    div_res = global_step / decay_steps
-    if staircase:
-        div_res = ops.floor(div_res)
+        div_res = global_step / decay_steps
+        if staircase:
+            div_res = ops.floor(div_res)
 
-    decayed_lr = learning_rate / (1 + decay_rate * div_res)
+        decayed_lr = learning_rate / (1 + decay_rate * div_res)
 
-    return decayed_lr
+        return decayed_lr
 
 
 def polynomial_decay(learning_rate,
@@ -220,25 +224,28 @@ def polynomial_decay(learning_rate,
     Returns:
         Variable: The decayed learning rate
     """
-    global_step = _decay_step_counter()
-
-    if cycle:
-        div_res = ops.ceil(global_step / decay_steps)
-        zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0)
-        one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0)
-
-        with control_flow.Switch() as switch:
-            with switch.case(global_step == zero_var):
-                tensor.assign(input=one_var, output=div_res)
-        decay_steps = decay_steps * div_res
-    else:
-        decay_steps_var = tensor.fill_constant(
-            shape=[1], dtype='float32', value=float(decay_steps))
-        global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
+
+        if cycle:
+            div_res = ops.ceil(global_step / decay_steps)
+            zero_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=0.0)
+            one_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=1.0)
+
+            with control_flow.Switch() as switch:
+                with switch.case(global_step == zero_var):
+                    tensor.assign(input=one_var, output=div_res)
+            decay_steps = decay_steps * div_res
+        else:
+            decay_steps_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=float(decay_steps))
+            global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
 
-    decayed_lr = (learning_rate - end_learning_rate) * \
-        ((1 - global_step / decay_steps) ** power) + end_learning_rate
-    return decayed_lr
+        decayed_lr = (learning_rate - end_learning_rate) * \
+            ((1 - global_step / decay_steps) ** power) + end_learning_rate
+        return decayed_lr
 
 
 def piecewise_decay(boundaries, values):
@@ -266,34 +273,36 @@ def piecewise_decay(boundaries, values):
 
 
     """
+    with default_main_program()._lr_schedule_guard():
+        if len(values) - len(boundaries) != 1:
+            raise ValueError("len(values) - len(boundaries) should be 1")
 
-    if len(values) - len(boundaries) != 1:
-        raise ValueError("len(values) - len(boundaries) should be 1")
-
-    global_step = _decay_step_counter()
+        global_step = _decay_step_counter()
 
-    lr = tensor.create_global_var(
-        shape=[1],
-        value=0.0,
-        dtype='float32',
-        persistable=True,
-        name="learning_rate")
+        lr = tensor.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate")
 
-    with control_flow.Switch() as switch:
-        for i in range(len(boundaries)):
-            boundary_val = tensor.fill_constant(
+        with control_flow.Switch() as switch:
+            for i in range(len(boundaries)):
+                boundary_val = tensor.fill_constant(
+                    shape=[1],
+                    dtype='float32',
+                    value=float(boundaries[i]),
+                    force_cpu=True)
+                value_var = tensor.fill_constant(
+                    shape=[1], dtype='float32', value=float(values[i]))
+                with switch.case(global_step < boundary_val):
+                    tensor.assign(value_var, lr)
+            last_value_var = tensor.fill_constant(
                 shape=[1],
                 dtype='float32',
-                value=float(boundaries[i]),
-                force_cpu=True)
-            value_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(values[i]))
-            with switch.case(global_step < boundary_val):
-                tensor.assign(value_var, lr)
-        last_value_var = tensor.fill_constant(
-            shape=[1], dtype='float32', value=float(values[len(values) - 1]))
-        with switch.default():
-            tensor.assign(last_value_var, lr)
+                value=float(values[len(values) - 1]))
+            with switch.default():
+                tensor.assign(last_value_var, lr)
 
     return lr
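
With the guard in place, every op emitted by these decay helpers is tagged with :code:`OpRole.LRSched` (exposed through pybind earlier in this patch), so the distribute transpiler can recognize and place learning-rate schedule ops separately. The public API is unchanged; a usage sketch (hypothetical boundaries and values):

.. code-block:: python

    import paddle.fluid as fluid

    lr = fluid.layers.piecewise_decay(
        boundaries=[10000, 20000], values=[0.1, 0.01, 0.001])
    sgd = fluid.optimizer.SGD(learning_rate=lr)
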
 
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index ef7b16a19e..ad09005d86 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -43,11 +43,7 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """
 
-    def __init__(self,
-                 learning_rate,
-                 regularization=None,
-                 LARS_weight_decay=0.0,
-                 name=None):
+    def __init__(self, learning_rate, regularization=None, name=None):
         if not isinstance(learning_rate, float) and \
                 not isinstance(learning_rate, framework.Variable):
             raise TypeError("learning rate should be float or Variable")
@@ -68,7 +64,6 @@ class Optimizer(object):
         # {accum_name : { paramter_name : accumulator_for_parameter, ...}, ...}
         self._accumulators = defaultdict(lambda: dict())
         self.helper = None
-        self._LARS_weight_decay = LARS_weight_decay
 
     def _create_global_learning_rate(self):
         lr = self._global_learning_rate()
@@ -109,7 +104,6 @@ class Optimizer(object):
         param = param_and_grad[0]
         param_lr = param.optimize_attr['learning_rate']
         if type(param_lr) == Variable:
-            # param learning rate has been updated (LARS)
             print("returns updated param lr ", param_lr)
             return param_lr
         else:
@@ -227,10 +221,6 @@ class Optimizer(object):
             self._create_accumulators(loss.block,
                                       [p[0] for p in parameters_and_grads])
             self._create_global_learning_rate()
-            if self._LARS_weight_decay > 0.0:
-                layers.append_LARS(parameters_and_grads,
-                                   self._global_learning_rate(),
-                                   self._LARS_weight_decay)
 
             optimize_ops = []
             for param_and_grad in parameters_and_grads:
@@ -287,6 +277,9 @@ class SGDOptimizer(Optimizer):
     Args:
         learning_rate (float|Variable): the learning rate used to update parameters. \
         Can be a float value or a Variable with one float value as data element.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -295,10 +288,12 @@ class SGDOptimizer(Optimizer):
             sgd_optimizer.minimize(cost)
     """
 
-    def __init__(self, learning_rate, **kwargs):
+    def __init__(self, learning_rate, regularization=None, name=None):
         assert learning_rate is not None
         super(SGDOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "sgd"
 
     def _append_optimize_op(self, block, param_and_grad):
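
Spelling out ``regularization`` and ``name`` instead of forwarding ``**kwargs`` means a misspelled keyword argument now raises immediately rather than being silently swallowed. A usage sketch (hypothetical hyperparameter values):

.. code-block:: python

    import paddle.fluid as fluid

    sgd = fluid.optimizer.SGD(
        learning_rate=0.01,
        regularization=fluid.regularizer.L2Decay(1e-4))
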
@@ -343,6 +338,9 @@ class MomentumOptimizer(Optimizer):
         Can be a float value or a Variable with one float value as data element.
         momentum (float): momentum factor
         use_nesterov (bool): enables Nesterov momentum
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -352,11 +350,18 @@ class MomentumOptimizer(Optimizer):
     """
     _velocity_acc_str = "velocity"
 
-    def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 use_nesterov=False,
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert momentum is not None
         super(MomentumOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "momentum"
         self._momentum = momentum
         self._use_nesterov = bool(use_nesterov)
@@ -412,6 +417,9 @@ class AdagradOptimizer(Optimizer):
         learning_rate (float|Variable): the learning rate used to update parameters. \
         Can be a float value or a Variable with one float value as data element.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -421,11 +429,17 @@ class AdagradOptimizer(Optimizer):
     """
     _moment_acc_str = "moment"
 
-    def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 epsilon=1.0e-6,
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert epsilon is not None
         super(AdagradOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adagrad"
         self._epsilon = epsilon
 
@@ -485,6 +499,9 @@ class AdamOptimizer(Optimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -503,13 +520,16 @@ class AdamOptimizer(Optimizer):
                  beta1=0.9,
                  beta2=0.999,
                  epsilon=1e-8,
-                 **kwargs):
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
         super(AdamOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adam"
         self._beta1 = beta1
         self._beta2 = beta2
@@ -629,6 +649,9 @@ class AdamaxOptimizer(Optimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -645,13 +668,16 @@ class AdamaxOptimizer(Optimizer):
                  beta1=0.9,
                  beta2=0.999,
                  epsilon=1e-8,
-                 **kwargs):
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
         super(AdamaxOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adamax"
         self._beta1 = beta1
         self._beta2 = beta2
@@ -742,6 +768,9 @@ class DecayedAdagradOptimizer(Optimizer):
         Can be a float value or a Variable with one float value as data element.
         decay (float): decay rate.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -751,13 +780,20 @@ class DecayedAdagradOptimizer(Optimizer):
     """
     _moment_acc_str = "moment"
 
-    def __init__(self, learning_rate, decay=0.95, epsilon=1.0e-6, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 decay=0.95,
+                 epsilon=1.0e-6,
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert decay is not None
         assert epsilon is not None
 
         super(DecayedAdagradOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "decayed_adagrad"
         self._decay = decay
         self._epsilon = epsilon
@@ -811,6 +847,9 @@ class AdadeltaOptimizer(Optimizer):
         learning_rate(float): global learning rate
         rho(float): rho in equation
         epsilon(float): epsilon in equation
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -823,7 +862,12 @@ class AdadeltaOptimizer(Optimizer):
     _avg_squared_grad_acc_str = "_avg_squared_grad"
     _avg_squared_update_acc_str = "_avg_squared_update"
 
-    def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 epsilon=1.0e-6,
+                 rho=0.95,
+                 regularization=None,
+                 name=None):
         if learning_rate is None:
             raise ValueError("learning_rate is not set.")
         if epsilon is None:
@@ -831,7 +875,9 @@ class AdadeltaOptimizer(Optimizer):
         if rho is None:
             raise ValueError("rho is not set.")
         super(AdadeltaOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adadelta"
         self._epsilon = epsilon
         self._rho = rho
@@ -932,6 +978,9 @@ class RMSPropOptimizer(Optimizer):
             the gradient; if False, by the uncentered second moment. Setting this to
             True may help with training, but is slightly more expensive in terms of
             computation and memory. Defaults to False.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Raises:
         ValueError: If learning_rate, rho, epsilon, momentum are None.
@@ -953,9 +1002,12 @@ class RMSPropOptimizer(Optimizer):
                  epsilon=1.0e-6,
                  momentum=0.0,
                  centered=False,
-                 **kwargs):
+                 regularization=None,
+                 name=None):
         super(RMSPropOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         if learning_rate is None:
             raise ValueError("learning_rate is not set.")
         if rho is None:
@@ -1061,6 +1113,9 @@ class FtrlOptimizer(Optimizer):
         l1 (float):
         l2 (float):
         lr_power (float):
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
 
     Raises:
         ValueError: If learning_rate, rho, epsilon, momentum are None.
@@ -1075,9 +1130,17 @@ class FtrlOptimizer(Optimizer):
     _squared_acc_str = "squared"
     _linear_acc_str = "linear"
 
-    def __init__(self, learning_rate, l1=0.0, l2=0.0, lr_power=-0.5, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 l1=0.0,
+                 l2=0.0,
+                 lr_power=-0.5,
+                 regularization=None,
+                 name=None):
         super(FtrlOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         if learning_rate is None:
             raise ValueError("learning_rate is not set.")
 
@@ -1155,7 +1218,9 @@ class ModelAverage(Optimizer):
         average_window_rate: The rate of average window.
         min_average_window: The minimum size of average window.
         max_average_window: The maximum size of average window.
-
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: An optional name prefix.
     Examples:
 
       .. code-block:: python
@@ -1178,8 +1243,10 @@ class ModelAverage(Optimizer):
                  average_window_rate,
                  min_average_window=10000,
                  max_average_window=10000,
-                 **kwargs):
-        super(ModelAverage, self).__init__(0.0, **kwargs)
+                 regularization=None,
+                 name=None):
+        super(ModelAverage, self).__init__(
+            0.0, regularization=regularization, name=name)
         self.average_window = average_window_rate
         self.min_average_window = min_average_window
         self.max_average_window = max_average_window
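
The hunks above replace the catch-all `**kwargs` with explicit `regularization` and `name` keywords on every optimizer subclass. A minimal sketch of the new call site (nothing here is specific to this PR beyond the two keywords):

```python
import paddle.fluid as fluid

# regularization and name are now explicit keyword arguments instead of
# being forwarded through **kwargs to the base Optimizer.
optimizer = fluid.optimizer.AdagradOptimizer(
    learning_rate=1e-3,
    epsilon=1e-6,
    regularization=fluid.regularizer.L2DecayRegularizer(
        regularization_coeff=1e-4),
    name="adagrad")
```

A side effect of dropping `**kwargs` is that stale arguments such as the removed `LARS_weight_decay` now fail fast with a `TypeError` instead of being silently swallowed, which is why the call in `test_recognize_digits.py` below had to change.
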
diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py
index 8f4678649f..a4336e955f 100644
--- a/python/paddle/fluid/regularizer.py
+++ b/python/paddle/fluid/regularizer.py
@@ -190,14 +190,11 @@ class L1DecayRegularizer(WeightDecayRegularizer):
     Examples:
         .. code-block:: python
 
-            program = fluid.framework.Program()
-            block = program.global_block()
-            mul_x = block.create_parameter(
-                dtype="float32",
-                shape=[5, 10],
-                lod_level=0,
-                name="mul.x",
-                regularizer=fluid.regularizer.L1DecayRegularizer(0.5))
+            optimizer = fluid.optimizer.Adagrad(
+                learning_rate=1e-4,
+                regularization=fluid.regularizer.L1DecayRegularizer(
+                    regularization_coeff=0.1))
+            optimizer.minimize(avg_cost)
     """
 
     def __init__(self, regularization_coeff=0.0):
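
The docstring example now routes the regularizer through the optimizer instead of through `create_parameter`. A self-contained sketch of the recommended pattern (the tiny network is illustrative, not from this PR):

```python
import paddle.fluid as fluid

image = fluid.layers.data(name='img', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=image, size=10, act='softmax')
avg_cost = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=label))

# The regularizer is applied to every trainable parameter the optimizer
# updates; a regularizer set on an individual ParamAttr is expected to
# take precedence for that parameter.
optimizer = fluid.optimizer.Adagrad(
    learning_rate=1e-4,
    regularization=fluid.regularizer.L1DecayRegularizer(
        regularization_coeff=0.1))
optimizer.minimize(avg_cost)
```
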
diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py
index 135f11d24c..4b4f3e4037 100644
--- a/python/paddle/fluid/tests/book/test_recognize_digits.py
+++ b/python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -99,7 +99,7 @@ def train(nn_type,
 
     test_program = fluid.default_main_program().clone(for_test=True)
 
-    optimizer = fluid.optimizer.Adam(learning_rate=0.001, LARS_weight_decay=0.3)
+    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
     optimizer.minimize(avg_loss)
 
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 88d36fe639..d02c890209 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -34,12 +34,13 @@ if(APPLE)
         list(REMOVE_ITEM TEST_OPS test_desc_clone)
         list(REMOVE_ITEM TEST_OPS test_program_code)
     endif(NOT WITH_DISTRIBUTE)
-    message(WARNING "These tests has been disabled in OSX before being fixed: \n test_detection_map_op \n test_dist_se_resnext")
+    message(WARNING "These tests have been disabled on OSX until they are fixed: \n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext")
     # this op is not support on mac
     list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
     # TODO: add the unitest back when it fixed
     list(REMOVE_ITEM TEST_OPS test_detection_map_op)
     list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
+    list(REMOVE_ITEM TEST_OPS test_fuse_elewise_add_act_pass)
 endif()
 
 function(py_test_modules TARGET_NAME)
@@ -79,7 +80,8 @@ if(WITH_DISTRIBUTE)
         py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL)
     endif(NOT APPLE)
     py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
-    py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL)
+    #FIXME(gongwb): random fails.
+    #py_test_modules(test_dist_transformer MODULES test_dist_transformer SERIAL)
 endif()
 py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
 py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py
index 3ec79f8ef6..175bd130e5 100644
--- a/python/paddle/fluid/tests/unittests/dist_transformer.py
+++ b/python/paddle/fluid/tests/unittests/dist_transformer.py
@@ -437,13 +437,8 @@ def split_data(data, num_part):
     ]
 
 
-def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
+def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
                  sum_cost, token_num):
-    # Context to do validation.
-    test_program = train_progm.clone()
-    with fluid.program_guard(test_program):
-        test_program = fluid.io.get_inference_program([avg_cost])
-
     val_data = DataReader(
         src_vocab_fpath=TrainTaskConfig.src_vocab_fpath,
         trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath,
@@ -505,7 +500,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
 
 
 def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
-               token_num, predict):
+               token_num, predict, test_program):
     # Initialize the parameters.
     if TrainTaskConfig.ckpt_path:
         lr_scheduler.current_steps = TrainTaskConfig.start_step
@@ -554,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                                                                              -1] + label_data_input_fields
 
     if TrainTaskConfig.val_file_pattern is not None:
-        test = test_context(train_progm, avg_cost, train_exe, dev_count,
+        test = test_context(test_program, avg_cost, train_exe, dev_count,
                             data_input_names, sum_cost, token_num)
 
     # the best cross-entropy value with label smoothing
@@ -1647,6 +1642,8 @@ def get_model(is_dist, is_async):
     local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                                TrainTaskConfig.warmup_steps,
                                                TrainTaskConfig.learning_rate)
+    # Context to do validation.
+    test_program = fluid.default_main_program().clone(for_test=True)
 
     if not is_dist:
         optimizer = fluid.optimizer.Adam(
@@ -1671,7 +1668,7 @@ def get_model(is_dist, is_async):
             epsilon=TrainTaskConfig.eps)
         optimizer.minimize(sum_cost)
 
-    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler
+    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program
 
 
 def update_args():
@@ -1705,7 +1702,7 @@ class DistTransformer2x2(TestDistRunnerBase):
     def run_trainer(self, use_cuda, args):
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         TrainTaskConfig.use_gpu = use_cuda
-        sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
+        sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
             args.is_dist, not args.sync_mode)
 
         if args.is_dist:
@@ -1726,7 +1723,7 @@ class DistTransformer2x2(TestDistRunnerBase):
         TrainTaskConfig.local = not args.is_dist
 
         train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost,
-                   local_lr_scheduler, token_num, predict)
+                   local_lr_scheduler, token_num, predict, test_program)
 
 
 if __name__ == "__main__":
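
The validation program is now cloned in `get_model` before `minimize` is called, mirroring the `clone(for_test=True)` idiom used in `test_recognize_digits.py` above. A minimal sketch of the idiom (network construction producing `avg_cost` is assumed):

```python
import paddle.fluid as fluid

# ... build the forward network, producing avg_cost ...

# Clone for evaluation BEFORE appending optimizer ops: the clone then
# contains only forward ops, and ops such as dropout and batch_norm run
# in inference mode.
test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_cost)  # mutates the main program, not the clone
```
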
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index e97643cdde..b5549c507e 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -345,7 +345,7 @@ class OpTest(unittest.TestCase):
                         actual_t, expect_t, atol=atol, equal_nan=equal_nan),
                     "Output (" + out_name + ") has diff at " + str(place) +
                     "\nExpect " + str(expect_t) + "\n" + "But Got" +
-                    str(actual_t) + " in class " + self.__class__.__name__)
+                    str(actual_t))
                 if isinstance(expect, tuple):
                     self.assertListEqual(actual.recursive_sequence_lengths(),
                                          expect[1], "Output (" + out_name +
diff --git a/python/paddle/fluid/tests/unittests/test_detection_map_op.py b/python/paddle/fluid/tests/unittests/test_detection_map_op.py
index 0c5343a97d..f6eb8f2c6d 100644
--- a/python/paddle/fluid/tests/unittests/test_detection_map_op.py
+++ b/python/paddle/fluid/tests/unittests/test_detection_map_op.py
@@ -20,7 +20,6 @@ import six
 import sys
 import collections
 import math
-import paddle.fluid as fluid
 from op_test import OpTest
 
 
@@ -33,7 +32,7 @@ class TestDetectionMAPOp(OpTest):
         self.detect = np.array(self.detect).astype('float32')
         self.mAP = np.array(self.mAP).astype('float32')
 
-        if len(self.class_pos_count) > 0:
+        if (len(self.class_pos_count) > 0):
             self.class_pos_count = np.array(self.class_pos_count).astype(
                 'int32')
             self.true_pos = np.array(self.true_pos).astype('float32')
@@ -274,7 +273,7 @@ class TestDetectionMAPOp11Point(TestDetectionMAPOp):
 class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp):
     def init_test_case(self):
         super(TestDetectionMAPOpMultiBatch, self).init_test_case()
-        self.class_pos_count = [0, 2, 1, 0]
+        self.class_pos_count = [0, 2, 1]
         self.true_pos_lod = [[0, 3, 2]]
         self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]]
         self.false_pos_lod = [[0, 3, 2]]
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
index 59a137c18c..09b1c546e4 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
@@ -22,7 +22,7 @@ class TestDistMnist2x2(TestDistBase):
         self._sync_mode = True
         self._use_reduce = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)
 
 
@@ -31,7 +31,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
         self._sync_mode = True
         self._mem_opt = True
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)
 
 
@@ -40,7 +40,7 @@ class TestDistMnistAsync(TestDistBase):
         self._sync_mode = False
         self._use_reduce = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_mnist.py", delta=200)
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
index c0e9fa38e7..7c3ed09168 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
@@ -21,7 +21,16 @@ class TestDistSeResneXt2x2(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
+        self.check_with_place("dist_se_resnext.py", delta=1e-7)
+
+
+class TestDistseResnXt2x2WithMemopt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._mem_opt = True
+
+    def test_dist_train(self):
         self.check_with_place("dist_se_resnext.py", delta=1e-7)
 
 
@@ -29,7 +38,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_se_resnext.py", delta=100)
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transformer.py b/python/paddle/fluid/tests/unittests/test_dist_transformer.py
index 47083ca7e9..47e8dfaf03 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transformer.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transformer.py
@@ -59,7 +59,7 @@ class TestDistTransformer2x2Sync(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
-    def test_transformer(self):
+    def test_dist_train(self):
         download_files()
         self.check_with_place("dist_transformer.py", delta=1e-5)
 
@@ -68,7 +68,7 @@ class TestDistTransformer2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
 
-    def test_transformer(self):
+    def test_dist_train(self):
         download_files()
         self.check_with_place("dist_transformer.py", delta=1.0)
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
index 9a3e92e8d7..33b39b262b 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
@@ -17,19 +17,28 @@ import unittest
 from test_dist_base import TestDistBase
 
 
-class TestDistSeResneXt2x2(TestDistBase):
+class TestDistW2V2x2(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_word2vec.py", delta=1e-4)
 
 
-class TestDistSeResneXt2x2Async(TestDistBase):
+class TestDistW2V2x2WithMemOpt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._mem_opt = True
+
+    def test_dist_train(self):
+        self.check_with_place("dist_word2vec.py", delta=1e-4)
+
+
+class TestDistW2V2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
 
-    def test_se_resnext(self):
+    def test_dist_train(self):
         self.check_with_place("dist_word2vec.py", delta=1)
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py
index 9f6f03f9cf..f61a447fd7 100644
--- a/python/paddle/fluid/tests/unittests/test_gru_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gru_op.py
@@ -30,7 +30,8 @@ def gru(
         bias,  # 1 x 3D
         is_reverse,
         act_state,
-        act_gate):
+        act_gate,
+        dtype='float32'):
     def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
@@ -71,10 +72,10 @@ def gru(
     T = sum(lod[0])
     N = len(lod[0])
     D = weight.shape[0]
-    batch_gate = np.zeros((T, 3 * D), dtype='float64')
-    batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
-    batch_hidden = np.zeros((T, D), dtype='float64')
-    hidden = np.zeros((T, D), dtype='float64')
+    batch_gate = np.zeros((T, 3 * D), dtype=dtype)
+    batch_reset_hidden_prev = np.zeros((T, D), dtype=dtype)
+    batch_hidden = np.zeros((T, D), dtype=dtype)
+    hidden = np.zeros((T, D), dtype=dtype)
 
     idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
     h_p = h0[sorted_seqs]
@@ -108,23 +109,24 @@ class TestGRUOp(OpTest):
         self.with_bias = True
         self.act_state = 'tanh'
         self.act_gate = 'sigmoid'
+        self.dtype = 'float64'
         self.set_confs()
 
         T = sum(self.lod[0])
         N = len(self.lod[0])
 
-        input = np.random.rand(T, 3 * self.D).astype('float64')
-        weight = np.random.rand(self.D, 3 * self.D).astype('float64')
+        input = np.random.rand(T, 3 * self.D).astype(self.dtype)
+        weight = np.random.rand(self.D, 3 * self.D).astype(self.dtype)
         bias = np.random.rand(
-            1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
-                (1, 3 * self.D), dtype='float64')
+            1, 3 * self.D).astype(self.dtype) if self.with_bias else np.zeros(
+                (1, 3 * self.D), dtype=self.dtype)
         h0 = np.random.rand(
-            N, self.D).astype('float64') if self.with_h0 else np.zeros(
-                (N, self.D), dtype='float64')
+            N, self.D).astype(self.dtype) if self.with_h0 else np.zeros(
+                (N, self.D), dtype=self.dtype)
 
         batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
             input, self.lod, h0, weight, bias, self.is_reverse,
-            ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
         self.inputs = {'Input': (input, self.lod), 'Weight': weight}
 
         if self.with_bias:
@@ -153,6 +155,12 @@ class TestGRUOp(OpTest):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
 
 
+class TestGRUOp2(TestGRUOp):
+    def set_confs(self):
+        self.D = 19
+        self.dtype = 'float32'
+
+
 class TestGRUOpNoInitial(TestGRUOp):
     def set_confs(self):
         self.with_h0 = False
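
With `dtype` threaded through the reference `gru()`, new precision variants only need to override `set_confs`. A hypothetical further variant in the same style as `TestGRUOp2` (the class name and tolerance are illustrative, not part of this PR):

```python
class TestGRUOpFP32Wide(TestGRUOp):
    def set_confs(self):
        self.D = 40
        self.dtype = 'float32'

    def test_check_grad(self):
        # The float32 reference is less precise than the float64
        # default, so loosen the gradient tolerance for this variant.
        self.check_grad(
            ['Input', 'H0', 'Weight', 'Bias'], ['Hidden'],
            max_relative_error=0.01)
```
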
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 807c114b5b..3536b876c9 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -573,6 +573,16 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(out)
         print(str(program))
 
+    def test_roi_perspective_transform(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name="x", shape=[256, 30, 30], dtype="float32")
+            rois = layers.data(
+                name="rois", shape=[8], dtype="float32", lod_level=1)
+            output = layers.roi_perspective_transform(x, rois, 7, 7, 0.6)
+            self.assertIsNotNone(output)
+        print(str(program))
+
     def test_sequence_enumerate(self):
         program = Program()
         with program_guard(program):
diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py
new file mode 100644
index 0000000000..de67513156
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py
@@ -0,0 +1,306 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import math
+import sys
+import paddle.compat as cpt
+from op_test import OpTest
+from math import sqrt
+from math import floor
+
+
+def gt_e(a, b):
+    return a > b or abs(a - b) < 1e-4
+
+
+def gt(a, b):
+    return (a - b) > 1e-4
+
+
+def lt_e(a, b):
+    return a < b or abs(a - b) < 1e-4
+
+
+def in_quad(x, y, roi_x, roi_y):
+    # check if (x, y) is in the boundary of roi
+    for i in range(4):
+        xs = roi_x[i]
+        ys = roi_y[i]
+        xe = roi_x[(i + 1) % 4]
+        ye = roi_y[(i + 1) % 4]
+        if abs(ys - ye) < 1e-4:
+            if abs(y - ys) < 1e-4 and abs(y - ye) < 1e-4 and gt_e(
+                    x, min(xs, xe)) and lt_e(x, max(xs, xe)):
+                return True
+        else:
+            intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs
+            if abs(intersec_x - x) < 1e-4 and gt_e(y, min(ys, ye)) and lt_e(
+                    y, max(ys, ye)):
+                return True
+    n_cross = 0
+    for i in range(4):
+        xs = roi_x[i]
+        ys = roi_y[i]
+        xe = roi_x[(i + 1) % 4]
+        ye = roi_y[(i + 1) % 4]
+        if abs(ys - ye) < 1e-4:
+            continue
+        if lt_e(y, min(ys, ye)) or gt(y, max(ys, ye)):
+            continue
+        intersec_x = (y - ys) * (xe - xs) / (ye - ys) + xs
+        if abs(intersec_x - x) < 1e-4:
+            return True
+        if gt(intersec_x, x):
+            n_cross += 1
+    return (n_cross % 2 == 1)
+
+
+def get_transform_matrix(transformed_width, transformed_height, roi_x, roi_y):
+    x0 = roi_x[0]
+    x1 = roi_x[1]
+    x2 = roi_x[2]
+    x3 = roi_x[3]
+    y0 = roi_y[0]
+    y1 = roi_y[1]
+    y2 = roi_y[2]
+    y3 = roi_y[3]
+
+    len1 = sqrt((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1))
+    len2 = sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))
+    len3 = sqrt((x2 - x3) * (x2 - x3) + (y2 - y3) * (y2 - y3))
+    len4 = sqrt((x3 - x0) * (x3 - x0) + (y3 - y0) * (y3 - y0))
+    estimated_height = (len2 + len4) / 2.0
+    estimated_width = (len1 + len3) / 2.0
+
+    normalized_height = transformed_height
+    normalized_width = round(estimated_width *
+                             (normalized_height - 1) / estimated_height) + 1
+    normalized_width = min(normalized_width, transformed_width)
+
+    dx1 = x1 - x2
+    dx2 = x3 - x2
+    dx3 = x0 - x1 + x2 - x3
+    dy1 = y1 - y2
+    dy2 = y3 - y2
+    dy3 = y0 - y1 + y2 - y3
+    matrix = np.zeros([9])
+    matrix[6] = (dx3 * dy2 - dx2 * dy3) / (dx1 * dy2 - dx2 * dy1) / (
+        normalized_width - 1)
+    matrix[7] = (dx1 * dy3 - dx3 * dy1) / (dx1 * dy2 - dx2 * dy1) / (
+        normalized_height - 1)
+    matrix[8] = 1
+
+    matrix[3] = (y1 - y0 + matrix[6] *
+                 (normalized_width - 1) * y1) / (normalized_width - 1)
+    matrix[4] = (y3 - y0 + matrix[7] *
+                 (normalized_height - 1) * y3) / (normalized_height - 1)
+    matrix[5] = y0
+
+    matrix[0] = (x1 - x0 + matrix[6] *
+                 (normalized_width - 1) * x1) / (normalized_width - 1)
+    matrix[1] = (x3 - x0 + matrix[7] *
+                 (normalized_height - 1) * x3) / (normalized_height - 1)
+    matrix[2] = x0
+    return matrix
+
+
+def get_source_coords(matrix, out_w, out_h):
+    u = matrix[0] * out_w + matrix[1] * out_h + matrix[2]
+    v = matrix[3] * out_w + matrix[4] * out_h + matrix[5]
+    w = matrix[6] * out_w + matrix[7] * out_h + matrix[8]
+    in_w = u / w
+    in_h = v / w
+    return in_w, in_h
+
+
+def bilinear_interpolate(in_data, in_n, in_c, in_w, in_h):
+
+    batch_size = in_data.shape[0]
+    channels = in_data.shape[1]
+    height = in_data.shape[2]
+    width = in_data.shape[3]
+
+    if gt(-0.5, in_w) or gt(in_w, width - 0.5) or gt(-0.5, in_h) or gt(
+            in_h, height - 0.5):
+        return 0.0
+
+    if gt(0, in_w):
+        in_w = 0
+    if gt(0, in_h):
+        in_h = 0
+
+    in_w_floor = floor(in_w)
+    in_h_floor = floor(in_h)
+
+    if gt_e(in_w_floor, width - 1):
+        in_w_ceil = width - 1
+        in_w_floor = width - 1
+        in_w = in_w_floor
+    else:
+        in_w_ceil = in_w_floor + 1
+
+    if gt_e(in_h_floor, height - 1):
+        in_h_ceil = height - 1
+        in_h_floor = height - 1
+        in_h = in_h_floor
+    else:
+        in_h_ceil = in_h_floor + 1
+
+    w_floor = in_w - in_w_floor
+    h_floor = in_h - in_h_floor
+    w_ceil = 1 - w_floor
+    h_ceil = 1 - h_floor
+    v1 = in_data[in_n][in_c][int(in_h_floor)][int(in_w_floor)]
+    v2 = in_data[in_n][in_c][int(in_h_ceil)][int(in_w_floor)]
+    v3 = in_data[in_n][in_c][int(in_h_ceil)][int(in_w_ceil)]
+    v4 = in_data[in_n][in_c][int(in_h_floor)][int(in_w_ceil)]
+    w1 = w_ceil * h_ceil
+    w2 = w_ceil * h_floor
+    w3 = w_floor * h_floor
+    w4 = w_floor * h_ceil
+    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+    return val
+
+
+def lod_convert(lod):
+    ret = [0]
+    for count in lod:
+        ret.append(ret[-1] + count)
+    return ret
+
+
+def roi_transform(in_data, rois, rois_lod, transformed_height,
+                  transformed_width, spatial_scale):
+    channels = in_data.shape[1]
+    in_height = in_data.shape[2]
+    in_width = in_data.shape[3]
+    rois_num = rois.shape[0]
+
+    roi2image = [0] * rois_num
+    rois_lod = lod_convert(rois_lod[0])
+    for i in range(len(rois_lod) - 1):
+        for j in range(rois_lod[i], rois_lod[i + 1]):
+            roi2image[j] = i
+
+    out = np.zeros([rois_num, channels, transformed_height, transformed_width])
+
+    for n in range(rois_num):
+        roi_x = []
+        roi_y = []
+        for k in range(4):
+            roi_x.append(rois[n][2 * k] * spatial_scale)
+            roi_y.append(rois[n][2 * k + 1] * spatial_scale)
+        image_id = roi2image[n]
+        transform_matrix = get_transform_matrix(
+            transformed_width, transformed_height, roi_x, roi_y)
+
+        for c in range(channels):
+            for out_h in range(transformed_height):
+                for out_w in range(transformed_width):
+                    in_w, in_h = get_source_coords(transform_matrix, out_w,
+                                                   out_h)
+                    if in_quad(in_w, in_h, roi_x, roi_y) and gt_e(
+                            in_w, -0.5) and lt_e(in_w, in_width - 0.5) and gt_e(
+                                in_h, -0.5) and lt_e(in_h, in_height - 0.5):
+                        out[n][c][out_h][out_w] = bilinear_interpolate(
+                            in_data, image_id, c, in_w, in_h)
+                    else:
+                        out[n][c][out_h][out_w] = 0.0
+    return out.astype("float32")
+
+
+class TestROIPerspectiveTransformOp(OpTest):
+    def set_data(self):
+        self.init_test_case()
+        self.make_rois()
+
+        self.inputs = {'X': self.x, 'ROIs': (self.rois, self.rois_lod)}
+
+        self.attrs = {
+            'spatial_scale': self.spatial_scale,
+            'transformed_height': self.transformed_height,
+            'transformed_width': self.transformed_width
+        }
+        out = roi_transform(self.x, self.rois, self.rois_lod,
+                            self.transformed_height, self.transformed_width,
+                            self.spatial_scale)
+        self.outputs = {'Out': out}
+
+    def init_test_case(self):
+        self.batch_size = 2
+        self.channels = 2
+        self.height = 8
+        self.width = 8
+
+        # n, c, h, w
+        self.x_dim = (self.batch_size, self.channels, self.height, self.width)
+
+        self.spatial_scale = 1.0 / 2.0
+        self.transformed_height = 2
+        self.transformed_width = 3
+
+        self.x = np.random.random(self.x_dim).astype('float32')
+
+    def make_rois(self):
+        rois = []
+        self.rois_lod = [[]]
+        for bno in range(self.batch_size):
+            self.rois_lod[0].append(bno + 1)
+            for i in range(bno + 1):
+                x1 = np.random.randint(
+                    0,
+                    self.width // self.spatial_scale - self.transformed_width)
+                y1 = np.random.randint(
+                    0,
+                    self.height // self.spatial_scale - self.transformed_height)
+
+                x2 = np.random.randint(x1 + self.transformed_width,
+                                       self.width // self.spatial_scale)
+                y2 = np.random.randint(
+                    0,
+                    self.height // self.spatial_scale - self.transformed_height)
+
+                x3 = np.random.randint(x1 + self.transformed_width,
+                                       self.width // self.spatial_scale)
+                y3 = np.random.randint(y1 + self.transformed_height,
+                                       self.height // self.spatial_scale)
+
+                x4 = np.random.randint(
+                    0,
+                    self.width // self.spatial_scale - self.transformed_width)
+                y4 = np.random.randint(y1 + self.transformed_height,
+                                       self.height // self.spatial_scale)
+
+                roi = [x1, y1, x2, y2, x3, y3, x4, y4]
+                rois.append(roi)
+        self.rois_num = len(rois)
+        self.rois = np.array(rois).astype("float32")
+
+    def setUp(self):
+        self.op_type = "roi_perspective_transform"
+        self.set_data()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == '__main__':
+    unittest.main()
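
The NumPy reference above can also be exercised on its own, which is handy when debugging the op. A sketch with one feature map and a single square ROI (shapes chosen to match `init_test_case`):

```python
import numpy as np

x = np.random.random((1, 1, 8, 8)).astype('float32')
# One ROI as four (x, y) corners: (2, 2), (12, 2), (12, 12), (2, 12).
rois = np.array([[2, 2, 12, 2, 12, 12, 2, 12]], dtype='float32')
rois_lod = [[1]]  # the single ROI belongs to image 0

out = roi_transform(x, rois, rois_lod,
                    transformed_height=2,
                    transformed_width=3,
                    spatial_scale=0.5)
print(out.shape)  # (1, 1, 2, 3)
```
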
diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py
index 200175cfe8..59899e7e9a 100644
--- a/python/paddle/fluid/transpiler/details/program_utils.py
+++ b/python/paddle/fluid/transpiler/details/program_utils.py
@@ -21,13 +21,12 @@ import paddle
 
 
 def delete_ops(block, ops):
-    try:
-        start = list(block.ops).index(ops[0])
-        end = list(block.ops).index(ops[-1])
-        [block._remove_op(start) for _ in six.moves.range(end - start + 1)]
-    except Exception as e:
-        raise e
-    block.program._sync_with_cpp()
+    for op in ops:
+        try:
+            idx = list(block.ops).index(op)
+            block._remove_op(idx)
+        except Exception as e:
+            print(e)
 
 
 def find_op_by_input_arg(block, arg_name):
@@ -37,10 +36,18 @@ def find_op_by_input_arg(block, arg_name):
     return -1
 
 
-def find_op_by_output_arg(block, arg_name):
-    for index, op in enumerate(block.ops):
-        if arg_name in op.output_arg_names:
-            return index
+def find_op_by_output_arg(block, arg_name, reverse=False):
+    if reverse:
+        pos = len(block.ops) - 1
+        while pos >= 0:
+            op = block.ops[pos]
+            if arg_name in op.output_arg_names:
+                return pos
+            pos -= 1
+    else:
+        for index, op in enumerate(block.ops):
+            if arg_name in op.output_arg_names:
+                return index
     return -1
 
 
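`find_op_by_output_arg` gains a reverse scan because a gradient variable can be written by more than one op (the producing op and, e.g., a later scale or clip op), and the transpiler must insert split ops after the last writer. A toy illustration with stand-in objects carrying only the attributes the helper reads:

```python
from paddle.fluid.transpiler.details.program_utils import find_op_by_output_arg


class FakeOp(object):
    def __init__(self, outs):
        self.output_arg_names = outs


class FakeBlock(object):
    def __init__(self, ops):
        self.ops = ops


block = FakeBlock([FakeOp(["w@GRAD"]), FakeOp(["w@GRAD"])])
assert find_op_by_output_arg(block, "w@GRAD") == 0                # first writer
assert find_op_by_output_arg(block, "w@GRAD", reverse=True) == 1  # last writer
```
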
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index f58f1883a4..3f8c7b844a 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -50,6 +50,15 @@ OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
+DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
+LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
+
+PRINT_LOG = False
+
+
+def log(*args):
+    if PRINT_LOG:
+        print(args)
 
 
 class VarBlock:
@@ -127,6 +136,7 @@ class DistributeTranspilerConfig(object):
     slice_var_up = True
     split_method = None
     min_block_size = 8192
+    print_log = False
 
 
 class DistributeTranspiler(object):
@@ -174,6 +184,9 @@ class DistributeTranspiler(object):
         if self.config.split_method is None:
             self.config.split_method = RoundRobin
 
+        global PRINT_LOG
+        if self.config.print_log:
+            PRINT_LOG = True
         assert (self.config.min_block_size >= 8192)
         assert (self.config.split_method.__bases__[0] == PSDispatcher)
 
@@ -257,12 +270,12 @@ class DistributeTranspiler(object):
             splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
                 splited_grad_varname = splited_vars[0].name
-                index = find_op_by_output_arg(program.global_block(),
-                                              splited_grad_varname)
+                index = find_op_by_output_arg(
+                    program.global_block(), splited_grad_varname, reverse=True)
             elif len(splited_vars) > 1:
                 orig_var = program.global_block().vars[splited_grad_varname]
-                index = find_op_by_output_arg(program.global_block(),
-                                              splited_grad_varname)
+                index = find_op_by_output_arg(
+                    program.global_block(), splited_grad_varname, reverse=True)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
@@ -301,7 +314,7 @@ class DistributeTranspiler(object):
                 self.grad_name_to_send_dummy_out[
                     self.table_name] = program.global_block().create_var(
                         name=framework.generate_control_dev_var_name())
-            input_deps = self.grad_name_to_send_dummy_out.values()
+            input_deps = list(self.grad_name_to_send_dummy_out.values())
 
             program.global_block().append_op(
                 type="send_barrier",
@@ -377,7 +390,10 @@ class DistributeTranspiler(object):
                 type="concat",
                 inputs={"X": splited_var},
                 outputs={"Out": [orig_param]},
-                attrs={"axis": 0})
+                attrs={
+                    "axis": 0,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
 
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
 
@@ -496,9 +512,9 @@ class DistributeTranspiler(object):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-        sys.stderr.write("get_pserver_program() is deprecated, call\
-            get_pserver_programs() to get pserver main and startup\
-            in a single call.")
+        sys.stderr.write("get_pserver_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -615,22 +631,31 @@ class DistributeTranspiler(object):
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program._create_block(pre_block_idx)
             optimize_blocks.append(per_opt_block)
+            optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
             # append grad merging ops before clip and weight decay
-            # cases may like:
-            # L2Decay op -> clip op -> optimize
+            # e.g. merge grad -> L2Decay op -> clip op -> optimize
+            merged_var = None
             for _, op in enumerate(self.optimize_ops):
-                # find the origin @GRAD var before clipping
-                grad_varname_for_block = __op_have_grad_input__(op)
-                if ufind.is_connected(op, opt_op) and grad_varname_for_block:
+                # find the origin grad var before clipping/L2Decay;
+                # merged_var should be the input var name of L2Decay.
+                grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
+                if op.attr(OP_ROLE_VAR_ATTR_NAME)[
+                        0] == optimize_target_param_name:
                     merged_var = self._append_pserver_grad_merge_ops(
                         per_opt_block, grad_varname_for_block, endpoint,
                         grad_to_block_id, self.origin_program)
-                    break  # append optimize op once then append other ops.
-            for _, op in enumerate(self.optimize_ops):
-                # optimizer is connected to itself
-                if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var, lr_ops)
+                    if merged_var:
+                        break  # append optimize op once then append other ops.
+            if merged_var:
+                for _, op in enumerate(self.optimize_ops):
+                    # optimizer is connected to itself
+                    if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \
+                        op not in global_ops:
+                        log("append opt op: ", op.type, op.input_arg_names,
+                            merged_var)
+                        __append_optimize_op__(op, per_opt_block,
+                                               grad_to_block_id, merged_var,
+                                               lr_ops)
 
         # dedup grad to ids list
         grad_to_block_id = list(set(grad_to_block_id))
@@ -726,17 +751,17 @@ class DistributeTranspiler(object):
         Returns:
             Program: parameter server side startup program.
         """
-        sys.stderr.write("get_startup_program() is deprecated, call\
-            get_pserver_programs() to get pserver main and startup\
-            in a single call.")
+        sys.stderr.write("get_startup_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         if pserver_program != None:
-            sys.stderr.write("passing pserver_program to get_startup_program()\
-                is deprecated, you can use new API get_pserver_programs() to\
-                get both pserver main program and startup program.")
+            sys.stderr.write("passing pserver_program to get_startup_program() \
+is deprecated, you can use new API get_pserver_programs() to \
+get both pserver main program and startup program.")
         if startup_program != None:
-            sys.stderr.write("passing startup_program to get_startup_program()\
-                is deprecated, use fluid.program_guard() or pass this argument\
-                to transpile() call.")
+            sys.stderr.write("passing startup_program to get_startup_program() \
+is deprecated, use fluid.program_guard() or pass this argument \
+to transpile() call.")
 
         s_prog = Program()
         orig_s_prog = self.startup_program
@@ -1302,7 +1327,10 @@ class DistributeTranspiler(object):
                 type="split_selected_rows",
                 inputs={"X": orig_var},
                 outputs={"Out": splited_vars},
-                attrs={"height_sections": height_sections})
+                attrs={
+                    "height_sections": height_sections,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
         elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
             sections = []
             for v in splited_vars:
@@ -1312,8 +1340,10 @@ class DistributeTranspiler(object):
                 type="split_byref",
                 inputs={"X": orig_var},
                 outputs={"Out": splited_vars},
-                attrs={"sections": sections}  # assume split evenly
-            )
+                attrs={
+                    "sections": sections,
+                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                })
         else:
             AssertionError("Variable type should be in set "
                            "[LOD_TENSOR, SELECTED_ROWS]")
@@ -1381,15 +1411,15 @@ class DistributeTranspiler(object):
         if not grad_block:
             # do not append this op if current endpoint
             # is not dealing with this grad block
-            return
+            return None
         orig_varname, block_name, trainer_name = self._get_varname_parts(
             grad_block.name)
         if block_name:
             merged_var_name = '.'.join([orig_varname, block_name])
         else:
             merged_var_name = orig_varname
-        merged_var = \
-            pserver_block.vars[merged_var_name]
+
+        merged_var = pserver_block.vars[merged_var_name]
         grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx))
         if self.sync_mode and self.trainer_num > 1:
             vars2merge = []
@@ -1473,7 +1503,6 @@ class DistributeTranspiler(object):
         outputs = self._get_output_map_from_op(
             self.origin_program.global_block().vars, opt_op)
         outputs["ParamOut"] = new_inputs["Param"]
-
         optimize_block.append_op(
             type=opt_op.type,
             inputs=new_inputs,
@@ -1618,6 +1647,16 @@ class DistributeTranspiler(object):
         return iomap
 
     def _get_lr_ops(self):
+        lr_ops = []
+        block = self.origin_program.global_block()
+        for op in block.ops:
+            if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int(
+                    LR_SCHED_OP_ROLE_ATTR_VALUE):
+                lr_ops.append(op)
+                log("append lr op: ", op.type)
+        return lr_ops
+
+    def _get_lr_ops_deprecated(self):
         lr_ops = []
         # find learning rate variables by optimize op
         lr_vars = set()
@@ -1670,20 +1709,21 @@ class DistributeTranspiler(object):
         block = self.origin_program.global_block()
         opt_ops = []
         params_grads = []
+        # tmp set to dedup
+        optimize_params = set()
         origin_var_dict = self.origin_program.global_block().vars
         for op in block.ops:
             if self._is_opt_role_op(op):
                 opt_ops.append(op)
-                # HACK(wuyi): if we find grad vars from input of optimize
-                # ops, we may get the output of clip op. Use syntax "@GRAD"
-                # and op_role_var to get the pair.
-                for input_name in op.input_arg_names:
-                    if input_name.find("@GRAD") != -1 and \
-                        op.attr(RPC_OP_ROLE_ATTR_NAME):
-                        param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
+                if op.attr(OP_ROLE_VAR_ATTR_NAME):
+                    param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
+                    grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
+                    if not param_name in optimize_params:
+                        optimize_params.add(param_name)
+                        log("adding param_grad pair: ", param_name, grad_name)
                         params_grads.append([
                             origin_var_dict[param_name],
-                            origin_var_dict[input_name]
+                            origin_var_dict[grad_name]
                         ])
             else:
                 pass
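
The new `print_log` flag is plumbed from `DistributeTranspilerConfig` into the module-level `log()` helper. A sketch of enabling it (the endpoints and trainer id are illustrative; a network and optimizer are assumed to have been built in the default program first):

```python
from paddle.fluid.transpiler.distribute_transpiler import (
    DistributeTranspiler, DistributeTranspilerConfig)

config = DistributeTranspilerConfig()
config.print_log = True  # makes log(...) calls inside the transpiler print

t = DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6170,127.0.0.1:6171",
    trainers=2)
```
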
diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
index d4517059a4..d5aa54d752 100755
--- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
+++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
@@ -14,10 +14,10 @@
 
 from __future__ import print_function
 
-from collections import defaultdict
+from collections import defaultdict, OrderedDict, Callable
 from .. import core
 from ... import compat as cpt
-from ..framework import Program, default_main_program, Parameter
+from ..framework import Program, default_main_program, Parameter, Variable
 from ..backward import _rename_arg_
 from functools import reduce
 from six.moves import range
@@ -113,8 +113,10 @@ class ControlFlowGraph(object):
     def _fill_pool(self, i, is_forward):
         block_desc = self._ops[i].block()
         in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
+        # NOTE: must sort the in_diff set for cases that get different cache var.
+        # FIXME(typhoonzero): maybe use a "sorted set" is better than this.
         can_optimize = [
-            x for x in in_diff
+            x for x in sorted(list(in_diff))
             if self._check_var_validity(block_desc, x, is_forward)
         ]
         if can_optimize:
@@ -220,8 +222,9 @@ class ControlFlowGraph(object):
             block_desc = op.block()
             is_forward = i < self._forward_num
             if self.pool:
+                # NOTE: must sort the in_diff set for cases that get different cache var.
                 defs_can_optimize = [
-                    x for x in self._defs[i]
+                    x for x in sorted(list(self._defs[i]))
                     if self._check_var_validity(block_desc, x, is_forward)
                 ]
                 out_pair = [
@@ -271,6 +274,8 @@ class ControlFlowGraph(object):
                         self._program.block(block_desc.id).var(cpt.to_text(
                             x)).desc = self._find_var(block_desc, cache_var,
                                                       is_forward)
+                        self._program.block(block_desc.id).vars[cpt.to_text(x)] = \
+                            Variable(self._program.block(block_desc.id), name=cpt.to_text(x))
                         self._update_graph(x, cache_var, begin_idx=i)
                         break
             self._fill_pool(i, is_forward)
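
Sorting `in_diff` and `self._defs[i]` makes cache-variable selection deterministic. A minimal illustration of why iterating a raw set is not reproducible across processes (string hashes are randomized per interpreter run unless PYTHONHASHSEED is fixed):

```python
names = {"fc_0.tmp_1", "fc_0.tmp_2", "relu_0.tmp_0"}

first = next(iter(names))   # may differ between interpreter runs
stable = sorted(names)[0]   # always "fc_0.tmp_1"
```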