diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 66d4aee09a..c9886cd118 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -160,6 +160,12 @@ paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None
 paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.elu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
+paddle.fluid.layers.relu6 ArgSpec(args=['x', 'threshold', 'name'], varargs=None, keywords=None, defaults=(6.0, None))
+paddle.fluid.layers.pow ArgSpec(args=['x', 'factor', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
+paddle.fluid.layers.stanh ArgSpec(args=['x', 'scale_a', 'scale_b', 'name'], varargs=None, keywords=None, defaults=(0.6666666666666666, 1.7159, None))
+paddle.fluid.layers.hard_sigmoid ArgSpec(args=['x', 'slope', 'offset', 'name'], varargs=None, keywords=None, defaults=(0.2, 0.5, None))
+paddle.fluid.layers.swish ArgSpec(args=['x', 'beta', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
 paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
@@ -276,12 +282,6 @@ paddle.fluid.layers.softsign ArgSpec(args=[], varargs='args', keywords='kwargs',
 paddle.fluid.layers.brelu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.leaky_relu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.soft_relu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.elu ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.relu6 ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.pow ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.stanh ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.hard_sigmoid ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
-paddle.fluid.layers.swish ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.uniform_random ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None))
 paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None))
diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h
index 71db8d952f..fc479a4c4a 100644
--- a/paddle/fluid/framework/details/reference_count_op_handle.h
+++ b/paddle/fluid/framework/details/reference_count_op_handle.h
@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
 
 namespace paddle {
@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
                          const std::vector<std::string> &var_names,
                          GarbageCollector<Tensor> *gc,
                          AtomicReferenceCountMap *ref_cnts)
-      : OpHandleBase(node),
-        scope_(scope),
-        var_names_(var_names),
-        gc_(gc),
-        ref_cnts_(ref_cnts) {
+      : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
     dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
         platform::DeviceContextPool::Instance().Get(place));
     if (IsStreamGarabageCollector()) {
       PADDLE_ENFORCE(cudaSetDevice(place.device));
       PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
     }
+
+    for (auto &name : var_names) AddVar(name);
   }
 
   ~ReferenceCountOpHandle() {
@@ -69,19 +68,40 @@ class ReferenceCountOpHandle : public OpHandleBase {
 
   std::string Name() const override { return "reference_count"; }
 
+  void AddVar(const std::string &name) {
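+    // Record one more reference to `name` held by this handle; repeated
+    // names are merged into a single entry with a per-name counter.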
+    auto it = var_names_.find(name);
+    if (it != var_names_.end())
+      ++(it->second);
+    else
+      var_names_[name] = 1;
+  }
+
  protected:
   void RunImpl() override {
     auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-    std::vector<LoDTensor *> tensors;
-    for (auto &name : var_names_) {
+    std::vector<Tensor *> tensors;
+    for (auto &pair : var_names_) {
+      auto &name = pair.first;
       auto it = ref_cnts_->find(name);
       if (it == ref_cnts_->end()) continue;
 
       auto *var = exec_scope->FindVar(name);
-      if (var == nullptr || !var->IsType<LoDTensor>()) continue;
-
-      if (it->second.fetch_sub(1) <= 1) {
-        tensors.emplace_back(var->GetMutable<LoDTensor>());
+      if (var == nullptr) continue;
+
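+      // Decrease the pending reference count by the number of references
+      // this handle holds; once it drops to zero, collect the underlying
+      // tensor (a LoDTensor, or the value tensor of a SelectedRows).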
+      if (var->IsType<LoDTensor>()) {
+        if (it->second.fetch_sub(pair.second) <= pair.second) {
+          tensors.emplace_back(var->GetMutable<LoDTensor>());
+        }
+      } else if (var->IsType<SelectedRows>()) {
+        if (it->second.fetch_sub(pair.second) <= pair.second) {
+          tensors.emplace_back(
+              var->GetMutable<SelectedRows>()->mutable_value());
+        }
       }
     }
 
@@ -91,7 +111,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
   }
 
  private:
-  void ClearTensors(const std::vector<LoDTensor *> &tensors) {
+  void ClearTensors(const std::vector<Tensor *> &tensors) {
     auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
     if (gc != nullptr) {
       auto compute_stream = dev_ctx_->stream();
@@ -112,7 +132,8 @@ class ReferenceCountOpHandle : public OpHandleBase {
 
   const Scope *scope_;
   platform::CUDADeviceContext *dev_ctx_;
-  std::vector<std::string> var_names_;
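+  // Maps each variable name to the number of references this handle releases.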
+  std::unordered_map<std::string, int> var_names_;
   GarbageCollector<Tensor> *gc_;       // not own
   AtomicReferenceCountMap *ref_cnts_;  // not own
   cudaEvent_t event_;
diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc
index 344754d5a1..b1ce551ce7 100644
--- a/paddle/fluid/framework/details/reference_count_pass.cc
+++ b/paddle/fluid/framework/details/reference_count_pass.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <queue>
 #include <string>
 #include <vector>
 
@@ -23,6 +24,27 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
+  std::queue<VarHandleBase *> queue;
+  queue.push(var_in);
+  do {
+    auto *var = queue.front();
+    queue.pop();
+    for (auto *op : var->PendingOps()) {
+      auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
+      if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
+        return compute_op;
+      }
+      for (auto *out_var : op->Outputs()) {
+        queue.push(out_var);
+      }
+    }
+  } while (!queue.empty());
+  return nullptr;
+}
+
 std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
@@ -34,6 +56,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
   // Step 2: Find all variables in non-computation ops which refers to variables
   // in computation ops
   std::unordered_set<std::string> names;
+  std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
+      compute_ref_cnt_map;
+
   auto get_ref_cnts_from_compute_op = [&](
       const std::unique_ptr<OpHandleBase> &op,
       const std::vector<VarHandleBase *> &vars) {
@@ -54,15 +79,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       VarDesc *var_desc = var_handle->Node()->Var();
       auto var_name = var_handle->Node()->Name();
 
-      // This is wierd but there is really some variables without var_desc
+      // This is weird but there are really some variables without var_desc
       // in computation_op
       if (var_desc == nullptr) {
         if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
           continue;
       } else {
-        if (var_desc->Persistable() ||
-            var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
+        if (var_desc->Persistable()) continue;
+        auto var_type = var_desc->Proto()->type().type();
+        if (var_type != proto::VarType::LOD_TENSOR &&
+            var_type != proto::VarType::SELECTED_ROWS) {
           continue;
+        }
       }
 
       // compute op only runs in one device
@@ -93,12 +121,36 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       if (ref_cnts.count(place.device) &&
           ref_cnts[place.device]->count(var_name)) {
         ++(*ref_cnts[place.device])[var_name];
+
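+        // Hand the extra decrement to a reference_count op attached to the
+        // next computation op downstream on the same place, so the variable
+        // is released only after that op has run.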
+        auto *next_compute_op = FindNextComputationOpHandle(var_handle);
+        if (next_compute_op != nullptr) {
+          if (compute_ref_cnt_map.count(next_compute_op)) {
+            compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
+            VLOG(5) << "Add reference count of " << var_name << " to Operator "
+                    << next_compute_op->Name();
+          } else {
+            // Create new reference_count_op_handle
+            ir::Node *ref_cnt_node = graph->CreateEmptyNode(
+                "reference_count", ir::Node::Type::kOperation);
+            auto *ref_cnt_handle = new ReferenceCountOpHandle(
+                ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
+                gcs[place.device].get(), cur_ref_cnts[place.device].get());
+            if (next_compute_op->Outputs().empty()) {
+              auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+              next_compute_op->AddOutput(dep_var);
+              graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+            }
+            ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
+            compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
+          }
+        }
       }
     }
   };
 
-  std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
-      compute_ref_cnt_map;
   auto &all_ops = graph->Get<GraphOps>(kGraphOps);
   for (auto &op : all_ops) {
     auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
@@ -113,11 +165,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     auto *ref_cnt_handle = new ReferenceCountOpHandle(
         ref_cnt_node, compute_op->GetScope(), place, in_var_names,
         gcs[place.device].get(), cur_ref_cnts[place.device].get());
-    auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-    compute_op->AddOutput(dep_var);
-    ref_cnt_handle->AddInput(dep_var);
-    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
-    compute_ref_cnt_map[compute_op] = ref_cnt_handle;
+    if (compute_op->Outputs().empty()) {
+      auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+      compute_op->AddOutput(dep_var);
+      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+    }
+    ref_cnt_handle->AddInput(compute_op->Outputs().front());
+    compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
   }
 
   for (auto &op : all_ops) {
@@ -131,7 +185,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     new_all_ops.emplace_back(std::move(op));
     auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
     if (it != compute_ref_cnt_map.end()) {
-      new_all_ops.emplace_back(it->second);
+      // Add LeafNode to ReferenceCountOpHandle
+      auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
+      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
+      it->second->AddOutput(dummy_leaf);
+      new_all_ops.emplace_back(std::move(it->second));
     }
   }
 
diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h
index 5b27068c9e..4cb1f3a80e 100644
--- a/paddle/fluid/operators/adam_op.h
+++ b/paddle/fluid/operators/adam_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
 #include <Eigen/Dense>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
@@ -306,26 +307,45 @@ class AdamOpKernel : public framework::OpKernel<T> {
         VLOG(3) << "grad row size is 0!!";
         return;
       }
-      // merge duplicated rows if any.
-      // The rows of grad_merge have been sorted inside MergeAdd functor
-      scatter::MergeAdd<DeviceContext, T> merge_func;
-      auto& grad_merge = *(ctx.scope()
-                               .NewScope()
-                               .Var("sparse_adam_grad_merge")
-                               ->GetMutable<framework::SelectedRows>());
-      merge_func(ctx.template device_context<DeviceContext>(), grad,
-                 &grad_merge);
+
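+      // Fast path: if grad's rows are already unique and ascending, MergeAdd
+      // would not change anything, so use grad directly and skip the copy.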
+      std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
+      bool is_strict_sorted = true;
+      for (size_t i = 1; i < cpu_rows.size(); ++i) {
+        if (cpu_rows[i - 1] >= cpu_rows[i]) {
+          is_strict_sorted = false;
+          break;
+        }
+      }
+
+      const framework::SelectedRows* grad_merge_ptr;
+      if (is_strict_sorted) {
+        grad_merge_ptr = &grad;
+      } else {
+        // merge duplicated rows if any.
+        // The rows of grad_merge have been sorted inside MergeAdd functor
+        scatter::MergeAdd<DeviceContext, T> merge_func;
+        auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
+                                   .Var()
+                                   ->GetMutable<framework::SelectedRows>();
+        merge_func(ctx.template device_context<DeviceContext>(), grad,
+                   grad_merge_var);
+        grad_merge_ptr = grad_merge_var;
+      }
+
+      auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      int64_t* rows = nullptr;
-// When compiled without CUDA, the CUDAMutableData() interface should not be
+      const int64_t* rows = nullptr;
+// When compiled without CUDA, the CUDAData() interface should not be
 // provided.
 #if defined(PADDLE_WITH_CUDA)
       if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace());
+        rows = grad_merge.rows().CUDAData(ctx.GetPlace());
       } else {
 #endif
-        rows = grad_merge.mutable_rows()->data();
+        rows = grad_merge.rows().data();
 
 #if defined(PADDLE_WITH_CUDA)
       }
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 058d939464..58384aa0f5 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -45,8 +45,9 @@ __all__ = [
     'lod_reset', 'lrn', 'pad', 'pad_constant_like', 'label_smooth', 'roi_pool',
     'dice_loss', 'image_resize', 'image_resize_short', 'resize_bilinear',
     'gather', 'scatter', 'sequence_scatter', 'random_crop', 'mean_iou', 'relu',
-    'log', 'crop', 'rank_loss', 'prelu', 'flatten', 'sequence_mask', 'stack',
-    'pad2d', 'unstack', 'sequence_enumerate', 'expand', 'sequence_concat',
+    'log', 'crop', 'rank_loss', 'elu', 'relu6', 'pow', 'stanh', 'hard_sigmoid',
+    'swish', 'prelu', 'flatten', 'sequence_mask', 'stack', 'pad2d', 'unstack',
+    'sequence_enumerate', 'expand', 'sequence_concat',
     'uniform_random_batch_size_like', 'gaussian_random', 'sampling_id',
     'gaussian_random_batch_size_like', 'sum', 'slice', 'shape'
 ]
@@ -5828,6 +5829,156 @@ def pad2d(input,
     return out
 
 
+@templatedoc()
+def elu(x, alpha=1.0, name=None):
+    """
+    ${comment}
+    Args:
+        x(${x_type}): ${x_comment}
+        alpha(${alpha_type}|1.0): ${alpha_comment}
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        output(${out_type}): ${out_comment}
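+
+    Examples:
+        A small usage sketch (the input shape below is just illustrative):
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3, 9, 5], dtype="float32")
+            y = fluid.layers.elu(x, alpha=0.2)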
+    """
+    helper = LayerHelper('elu', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='elu',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'alpha': alpha})
+    return out
+
+
+@templatedoc()
+def relu6(x, threshold=6.0, name=None):
+    """
+    ${comment}
+    Args:
+        x(${x_type}): ${x_comment}
+        threshold(${threshold_type}|6.0): ${threshold_comment}
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        output(${out_type}): ${out_comment}
+    """
+    helper = LayerHelper('relu6', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='relu6',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'threshold': threshold})
+    return out
+
+
+@templatedoc()
+def pow(x, factor=1.0, name=None):
+    """
+    ${comment}
+    Args:
+        x(${x_type}): ${x_comment}
+        factor(${factor_type}|1.0): ${factor_comment}
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        output(${out_type}): ${out_comment}
+    """
+    helper = LayerHelper('pow', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='pow',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'factor': factor})
+    return out
+
+
+@templatedoc()
+def stanh(x, scale_a=2.0 / 3.0, scale_b=1.7159, name=None):
+    """
+    ${comment}
+    Args:
+        x(${x_type}): ${x_comment}
+        scale_a(${scale_a_type}|2.0 / 3.0): ${scale_a_comment}
+        scale_b(${scale_b_type}|1.7159): ${scale_b_comment}
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        output(${out_type}): ${out_comment}
+    """
+    helper = LayerHelper('stanh', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='stanh',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'scale_a': scale_a,
+               'scale_b': scale_b})
+    return out
+
+
+@templatedoc()
+def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
+    """
+    ${comment}
+    Args:
+        x(${x_type}): ${x_comment}
+        slope(${slope_type}|0.2): ${slope_comment}
+        offset(${offset_type}|0.5): ${offset_comment}
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        output(${out_type}): ${out_comment}
+    """
+    helper = LayerHelper('hard_sigmoid', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='hard_sigmoid',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'slope': slope,
+               'offset': offset})
+    return out
+
+
+@templatedoc()
+def swish(x, beta=1.0, name=None):
+    """
+    ${comment}
+    Args:
+        x(${x_type}): ${x_comment}
+        beta(${beta_type}|1.0): ${beta_comment}
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        output(${out_type}): ${out_comment}
+    """
+    helper = LayerHelper('swish', **locals())
+    out = helper.create_tmp_variable(dtype=x.dtype)
+    helper.append_op(
+        type='swish',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'beta': beta})
+    return out
+
+
 def prelu(x, mode, param_attr=None, name=None):
     """
     Equation:
diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py
index 5191c88274..b2ae77e867 100644
--- a/python/paddle/fluid/layers/ops.py
+++ b/python/paddle/fluid/layers/ops.py
@@ -36,12 +36,6 @@ __activations__ = [
     'brelu',
     'leaky_relu',
     'soft_relu',
-    'elu',
-    'relu6',
-    'pow',
-    'stanh',
-    'hard_sigmoid',
-    'swish',
 ]
 
 __all__ = [