diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 89726bf985..2ced43f9e6 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -166,6 +166,8 @@ function(op_library TARGET)
       # Append first implemented MKLDNN activation operator
       if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
         file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
+      elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
+        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
       else()
         file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
       endif()
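
For orientation, the stub this appends — `USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);` — expands (per the `op_registry.h` changes later in this patch) to roughly the sketch below; the pasted symbol names are illustrative of the token concatenation, not copied from generated output:

```cpp
// Approximate expansion of the generated pybind stub (a sketch, assuming the
// macro pastes op_type, LIBRARY_TYPE and customized_name into the symbols):
extern int TouchOpKernelRegistrar_conv2d_MKLDNN_FP32();
UNUSED static int use_op_kernel_conv2d_MKLDNN_FP32_ =  // NOLINT
    TouchOpKernelRegistrar_conv2d_MKLDNN_FP32();
```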
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 4f286049d4..2722ea078e 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -194,6 +194,8 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
 paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
 paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
+paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@@ -299,6 +301,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i
 paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
 paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
 paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
 paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
 paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
@@ -419,3 +422,17 @@ paddle.fluid.Scope.drop_kids drop_kids(self: paddle.fluid.core.Scope) -> None
 paddle.fluid.Scope.find_var find_var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable
 paddle.fluid.Scope.new_scope new_scope(self: paddle.fluid.core.Scope) -> paddle.fluid.core.Scope
 paddle.fluid.Scope.var var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable
+paddle.reader.map_readers ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None)
+paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None)
+paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None)
+paddle.reader.chain ArgSpec(args=[], varargs='readers', keywords=None, defaults=None)
+paddle.reader.shuffle ArgSpec(args=['reader', 'buf_size'], varargs=None, keywords=None, defaults=None)
+paddle.reader.firstn ArgSpec(args=['reader', 'n'], varargs=None, keywords=None, defaults=None)
+paddle.reader.xmap_readers ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,))
+paddle.reader.PipeReader.__init__ ArgSpec(args=['self', 'command', 'bufsize', 'file_type'], varargs=None, keywords=None, defaults=(8192, 'plain'))
+paddle.reader.PipeReader.get_line ArgSpec(args=['self', 'cut_lines', 'line_break'], varargs=None, keywords=None, defaults=(True, '\n'))
+paddle.reader.multiprocess_reader ArgSpec(args=['readers', 'use_pipe', 'queue_size'], varargs=None, keywords=None, defaults=(True, 1000))
+paddle.reader.Fake.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
+paddle.reader.creator.np_array ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
+paddle.reader.creator.text_file ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None)
+paddle.reader.creator.recordio ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,))
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index c701a2ad63..e4c471d86b 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
 cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
 
 cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
+cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
-    shape_inference data_transform lod_tensor profiler transfer_scope_cache)
+    shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
 
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
 
@@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
 cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
 cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
 
-cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
+cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
 cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
 
 cc_test(tuple_test SRCS tuple_test.cc )
diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index 291d8ffc3c..a99cf53b41 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
   CheckInit();
   for (size_t i = 0; i < use_slots_.size(); ++i) {
     if (name == use_slots_[i]) {
-      if (use_slots_is_dense_[i]) {
-        feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
-      } else {
-        feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
-      }
+      feed_vec_[i] = var->GetMutable<LoDTensor>();
     }
   }
 }
@@ -301,6 +297,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
           "the data, please check if the data contains unresolvable "
           "characters.\nplease check this error line: %s",
           str);
+
       if (idx != -1) {
         (*instance)[idx].Init(all_slots_type_[i]);
         if ((*instance)[idx].GetType()[0] == 'f') {  // float
@@ -337,6 +334,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
       (*ins_vec)[i].InitOffset();
     }
   }
+
   for (size_t i = 0; i < instance.size(); ++i) {
     (*ins_vec)[i].AddIns(instance[i]);
   }
@@ -348,36 +346,25 @@ void MultiSlotDataFeed::PutToFeedVec(
     const auto& type = ins_vec[i].GetType();
     const auto& offset = ins_vec[i].GetOffset();
     int total_instance = static_cast<int>(offset.back());
+
     if (type[0] == 'f') {  // float
       const auto& feasign = ins_vec[i].GetFloatData();
-      if (feed_vec_[i].IsDense()) {
-        int size_in_each_batch = total_instance / batch_size_;
-        float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
-            {batch_size_, size_in_each_batch}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
-      } else {
-        float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
-            {total_instance, 1}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
-        LoD data_lod{offset};
-        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
-      }
+      float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
+          {total_instance, 1}, platform::CPUPlace());
+      memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
     } else if (type[0] == 'u') {  // uint64
       // no uint64_t type in paddlepaddle
       const auto& feasign = ins_vec[i].GetUint64Data();
-      if (feed_vec_[i].IsDense()) {
-        int size_in_each_batch = total_instance / batch_size_;
-        int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
-            {batch_size_, size_in_each_batch}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
-      } else {
-        int64_t* tensor_ptr =
-            feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
-                {total_instance, 1}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
-        LoD data_lod{offset};
-        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
-      }
+      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
+          {total_instance, 1}, platform::CPUPlace());
+      memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
+    }
+
+    LoD data_lod{offset};
+    feed_vec_[i]->set_lod(data_lod);
+    if (use_slots_is_dense_[i]) {
+      int dim = total_instance / batch_size_;
+      feed_vec_[i]->Resize({batch_size_, dim});
     }
   }
 }
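
The rewrite funnels every slot through a `LoDTensor`: data is always written as `{total_instance, 1}` with a LoD, and dense slots are merely reshaped afterwards. A minimal sketch of that flow, with made-up sizes:

```cpp
#include "paddle/fluid/framework/lod_tensor.h"

// A sketch of the unified feed path; sizes are illustrative.
void SketchPutSlot() {
  paddle::framework::LoDTensor slot;
  std::vector<size_t> offset = {0, 2, 4};  // two instances, two feasigns each
  int total_instance = static_cast<int>(offset.back());
  float* ptr = slot.mutable_data<float>({total_instance, 1},
                                        paddle::platform::CPUPlace());
  (void)ptr;  // feasign data would be memcpy'd here
  slot.set_lod(paddle::framework::LoD{offset});
  const bool is_dense = true;  // comes from use_slots_is_dense_[i]
  const int batch_size = 2;
  if (is_dense) {
    slot.Resize({batch_size, total_instance / batch_size});  // -> {2, 2}
  }
}
```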
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index a7f8d1d317..7cc6919703 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -30,35 +30,6 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-// Pack Tensor type and LoDTensor type into MixTensor type, in order
-// to record either Tensor or LoDTensor information at the same time.
-class MixTensor {
- public:
-  MixTensor() {}
-  explicit MixTensor(LoDTensor* lodtensor) {
-    is_dense_ = false;
-    lodtensor_ = lodtensor;
-  }
-  explicit MixTensor(Tensor* tensor) {
-    is_dense_ = true;
-    tensor_ = tensor;
-  }
-  bool IsDense() { return is_dense_; }
-  LoDTensor* GetLoDTensor() {
-    PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
-    return lodtensor_;
-  }
-  Tensor* GetTensor() {
-    PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
-    return tensor_;
-  }
-
- private:
-  bool is_dense_;
-  LoDTensor* lodtensor_;
-  Tensor* tensor_;
-};
-
 // DataFeed is the base virtual class for all other DataFeeds.
 // It is used to read files and parse the data for subsequent trainers.
 // Example:
@@ -133,7 +104,7 @@ class DataFeed {
       use_slots_index_;  // -1: not used; >=0: the index of use_slots_
 
   // The data read by DataFeed will be stored here
-  std::vector<MixTensor> feed_vec_;
+  std::vector<LoDTensor*> feed_vec_;
 
   // the batch size defined by user
   int default_batch_size_;
diff --git a/paddle/fluid/framework/data_feed_test.cc b/paddle/fluid/framework/data_feed_test.cc
index 3974f8dbad..b3e9698715 100644
--- a/paddle/fluid/framework/data_feed_test.cc
+++ b/paddle/fluid/framework/data_feed_test.cc
@@ -152,19 +152,13 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
       const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
       std::map<std::string, const paddle::framework::LoDTensor*>
           lodtensor_targets;
-      std::map<std::string, const paddle::framework::Tensor*> tensor_targets;
       for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
         const auto& slot = multi_slot_desc.slots(i);
         if (slot.is_used()) {
           const auto& name = slot.name();
           readers[idx]->AddFeedVar(scope->Var(name), name);
-          if (slot.is_dense()) {
-            tensor_targets[name] =
-                &scope->FindVar(name)->Get<paddle::framework::Tensor>();
-          } else {
-            lodtensor_targets[name] =
-                &scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
-          }
+          lodtensor_targets[name] =
+              &scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
         }
       }
       readers[idx]->Start();
@@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
           if (!slot.is_used()) {
             continue;
           }
+          const paddle::framework::LoDTensor* tens =
+              lodtensor_targets[slot.name()];
           if (slot.is_dense()) {  // dense branch
-            const paddle::framework::Tensor* tens = tensor_targets[slot.name()];
             if (slot.type() == "uint64") {
               const int64_t* data = tens->data<int64_t>();
               int batch_size = tens->dims()[0];
@@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
               PADDLE_THROW("Error type in proto file.");
             }
           } else {  // sparse branch
-            const paddle::framework::LoDTensor* tens =
-                lodtensor_targets[slot.name()];
             if (slot.type() == "uint64") {
               const int64_t* data = tens->data<int64_t>();
               for (size_t i = 0; i < tens->NumElements(); ++i) {
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 93288936fe..2f76cb714f 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -15,14 +15,26 @@ cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_ro
 if(WITH_GPU)
     nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
             dynload_cuda variable_visitor)
-    nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
+    if(WITH_DISTRIBUTE)
+        nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
+    else()
+        nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim dynload_cuda selected_rows_functor)
+    endif()
     nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
     nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 
 else()
     cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
              variable_visitor)
-    cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
+    if(WITH_DISTRIBUTE)
+        cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim selected_rows_functor sendrecvop_grpc)
+    else()
+        cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
+            ddim selected_rows_functor)
+    endif()
     cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
     cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 endif()
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index a003995ae3..e8bf53e160 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -48,7 +48,14 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 void AllReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
 
+// FIXME(typhoonzero): If scope0 (the global scope) has NCCL_ID_VAR,
+// this is a distributed or inter-process run; find a better way to detect it.
+#ifdef PADDLE_WITH_CUDA
+  if (NoDummyInputSize() == 1 &&
+      local_scopes_[0]->FindLocalVar(NCCL_ID_VARNAME) == nullptr) {
+#else
   if (NoDummyInputSize() == 1) {
+#endif
     return;  // No need to all reduce when GPU count = 1;
   } else {
     // Wait input done
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 523f9eadf2..d8526b3f24 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -58,10 +58,23 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       }
     }
 
+    CollectiveContext *context = CollectiveContext::GetInstance();
+    context->endpoints_ = strategy_.trainers_endpoints_;
+    context->trainer_id_ = strategy_.trainer_id_;
+    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ must be >= 0");
+    if (strategy_.trainer_id_ > 0) {
+      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
+                         strategy_.trainers_endpoints_.size(),
+                     "trainer_id_ must be less than trainers_endpoints_ size");
+    }
+    }
+    VLOG(1) << "CollectiveContext:" << context->String();
+
     // Convert graph to run on multi-devices.
     auto multi_devices_pass = AppendPass("multi_devices_pass");
     multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                          &strategy_);
+    multi_devices_pass->Set<int>("num_trainers",
+                                 new int(strategy_.num_trainers_));
 
     // Add a graph print pass to record a graph with device info.
     if (!strategy_.debug_graphviz_path_.empty()) {
@@ -133,16 +146,16 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
     } else if (pass->Type() == "sequential_execution_pass") {
-      VLOG(1) << "set enable_sequential_execution:"
-              << enable_sequential_execution_;
+      LOG(INFO) << "set enable_sequential_execution:"
+                << enable_sequential_execution_;
 
       pass->Erase(kAllOpDescs);
       pass->Set<const std::vector<OpDesc *>>(
           kAllOpDescs,
           new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
     } else if (pass->Type() == "all_reduce_deps_pass") {
-      VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
-              << ", num_trainers:" << num_trainers_;
+      LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
+                << ", num_trainers:" << num_trainers_;
 
       pass->Erase(kAllOpDescs);
       pass->Set<const std::vector<OpDesc *>>(
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 9f0a259128..c97be16957 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -74,6 +74,8 @@ struct BuildStrategy {
   bool fuse_broadcast_op_{false};
 
   int num_trainers_{1};
+  int trainer_id_{0};
+  std::vector<std::string> trainers_endpoints_;
   bool remove_unnecessary_lock_{false};
 
   // NOTE:
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 03f5f2e73a..cbae5321d9 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -133,6 +133,7 @@ static const char kPlaces[] = "places";
 static const char kParams[] = "params";
 static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
+static const char kNumTrainers[] = "num_trainers";
 
 void MultiDevSSAGraphBuilder::Init() const {
   all_vars_.clear();
@@ -299,6 +300,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
 
+  int num_trainers = Get<int>(kNumTrainers);
+
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
       all_vars_.emplace(node->Name(), node->Var());
@@ -383,7 +386,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateComputationalOps(&result, node, places_.size());
         }
 
-        if (!is_forwarding && places_.size() > 1) {
+        if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
           // Currently, we assume that once gradient is generated, it can be
           // broadcast, and each gradient is only broadcast once.
           if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
@@ -895,4 +898,5 @@ REGISTER_PASS(multi_devices_pass,
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kParams)
     .RequirePassAttr(paddle::framework::details::kLocalScopes)
-    .RequirePassAttr(paddle::framework::details::kStrategy);
+    .RequirePassAttr(paddle::framework::details::kStrategy)
+    .RequirePassAttr(paddle::framework::details::kNumTrainers);
diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h
index 1ce18c3d6b..eea7e712f8 100644
--- a/paddle/fluid/framework/details/op_registry.h
+++ b/paddle/fluid/framework/details/op_registry.h
@@ -32,9 +32,7 @@ enum OpInfoFillType {
   kOpProtoAndCheckerMaker = 1,
   kGradOpDescMaker = 2,
   kVarTypeInference = 3,
-  kShapeInference = 4,
-  kEstimateFlops = 5,
-  kUnknown = -1
+  kShapeInference = 4
 };
 
 template <typename T>
@@ -50,10 +48,8 @@ struct OpInfoFillTypeID {
                                     ? kVarTypeInference
                                     : (std::is_base_of<InferShapeBase, T>::value
                                            ? kShapeInference
-                                           : (std::is_base_of<EstimateFlopsBase,
-                                                              T>::value
-                                                  ? kEstimateFlops
-                                                  : kUnknown)))));
+                                           : static_cast<OpInfoFillType>(
+                                                 -1)))));
   }
 };
 
@@ -143,16 +139,6 @@ struct OpInfoFiller<T, kShapeInference> {
   }
 };
 
-template <typename T>
-struct OpInfoFiller<T, kEstimateFlops> {
-  void operator()(const char* op_tpe, OpInfo* info) const {
-    info->estimate_flops_ = [](InferShapeContext* ctx) {
-      T estimate_flops;
-      return estimate_flops(ctx);
-    };
-  }
-};
-
 }  // namespace details
 
 }  // namespace framework
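
With `kEstimateFlops` gone, `OpInfoFillTypeID` resolves every registered class to one of the four remaining fillers, falling back to `OpInfoFillType(-1)` where `kUnknown` used to be. A small sketch of how the trait is consulted (`MyInferShape` is a hypothetical class):

```cpp
// Hypothetical functor: deriving from InferShapeBase makes the trait pick
// the shape-inference filler; a type matching no base now yields
// OpInfoFillType(-1), the value formerly named kUnknown.
struct MyInferShape : public paddle::framework::InferShapeBase {
  void operator()(paddle::framework::InferShapeContext* ctx) const override {}
};

auto fill_type =
    paddle::framework::details::OpInfoFillTypeID<MyInferShape>::ID();
// fill_type == paddle::framework::details::kShapeInference
```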
diff --git a/paddle/fluid/framework/details/reduce_and_gather.h b/paddle/fluid/framework/details/reduce_and_gather.h
index bd6153c0c7..2e5256fbd4 100644
--- a/paddle/fluid/framework/details/reduce_and_gather.h
+++ b/paddle/fluid/framework/details/reduce_and_gather.h
@@ -53,7 +53,7 @@ struct ReduceLoDTensor {
   }
 };
 
-inline void GatherSelectedRows(
+inline void GatherLocalSelectedRows(
     const std::vector<const SelectedRows *> &src_selecte_rows_,
     const std::vector<platform::Place> &in_places,
     const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc
index c9f1107aea..cb864848b9 100644
--- a/paddle/fluid/framework/details/reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/reduce_op_handle.cc
@@ -16,6 +16,12 @@
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/collective_client.h"
+#include "paddle/fluid/operators/distributed/collective_server.h"
+#include "paddle/fluid/operators/distributed/request_handler.h"
+#endif
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/profiler.h"
 
 DEFINE_bool(
@@ -26,6 +32,112 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+std::once_flag CollectiveContext::init_flag_;
+std::unique_ptr<CollectiveContext> CollectiveContext::context_;
+
+static inline std::string GetRemoteVarName(const std::string &var_name,
+                                           int trainer_id) {
+  return string::Sprintf("%s_merged_tmp@trainer_%d", var_name, trainer_id);
+}
+
+void ReduceOpHandle::Wait(
+    const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes) {
+  // TODO(gongwb): use event wait?
+  for (auto &dev_ctx : dev_ctxes) {
+    dev_ctx.second->Wait();
+  }
+}
+
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+template <typename DevCtx, typename DataType>
+void ReduceOpHandle::GatherSelectedRows(
+    const std::vector<const SelectedRows *> &src_selected_rows,
+    const std::vector<platform::Place> &in_places,
+    const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
+    VarHandle *out_var_handle, const platform::Place &out_place,
+    SelectedRows *dst_selected_rows) {
+  const CollectiveContext &collective_context =
+      *CollectiveContext::GetInstance();
+
+  // 1. gather local selected rows, merge them
+  std::string gathered_var_name = out_var_handle->name_ + "_gathered_tmp";
+  auto scope = local_scopes_.at(out_var_handle->scope_idx_);
+  auto gathered_var_mid = scope->Var(gathered_var_name);
+  auto gathered_select_rows =
+      gathered_var_mid->GetMutable<framework::SelectedRows>();
+  GatherLocalSelectedRows(src_selected_rows, in_places, dev_ctxes, out_place,
+                          gathered_select_rows);
+  // FIXME(gongwb): remove this Wait.
+  Wait(dev_ctxes);
+
+  // merge them
+  auto merged_dev_ctx = dynamic_cast<DevCtx *>(dev_ctxes.at(out_place));
+  std::string merged_var_name =
+      GetRemoteVarName(out_var_handle->name_, collective_context.trainer_id_);
+  auto merged_select_rows =
+      scope->Var(merged_var_name)->GetMutable<SelectedRows>();
+  operators::math::scatter::MergeAdd<DevCtx, DataType> merge_func;
+  merge_func(*merged_dev_ctx, *gathered_select_rows, merged_select_rows);
+
+  // 2. start collective server if it doesn't exist
+  operators::distributed::CollectiveServer *server =
+      operators::distributed::CollectiveServer::GetInstance(
+          collective_context.endpoints_[collective_context.trainer_id_],
+          collective_context.endpoints_.size() - 1);
+
+  auto rpc_server = server->GetRPCServer();
+  rpc_server->RegisterVar(merged_var_name,
+                          operators::distributed::kRequestGetMonomerVariable,
+                          scope, merged_dev_ctx);
+
+  // 3. gather them from all remote nodes.
+  std::vector<const SelectedRows *> remote;
+  operators::distributed::CollectiveClient *client =
+      operators::distributed::CollectiveClient::GetInstance();
+
+  std::vector<operators::distributed::RemoteVar> vars;
+  for (unsigned int i = 0; i < collective_context.endpoints_.size(); i++) {
+    if (i == (unsigned)collective_context.trainer_id_) continue;
+
+    operators::distributed::RemoteVar var;
+    var.trainer_id_ = i;
+    var.var_name_ = GetRemoteVarName(out_var_handle->name_, i);
+    var.ep_ = collective_context.endpoints_[i];
+
+    vars.push_back(var);
+    VLOG(4) << "gather from:" << var.String();
+  }
+
+  // erase gathered vars
+  merged_dev_ctx->Wait();
+  scope->EraseVars(std::vector<std::string>{gathered_var_name});
+
+  PADDLE_ENFORCE(client->Gather(vars, &remote, *merged_dev_ctx, scope));
+  PADDLE_ENFORCE(remote.size() == vars.size());
+
+  // 4. merge local and remote selected rows.
+  std::vector<const SelectedRows *> all;
+  all.resize(collective_context.endpoints_.size());
+  for (auto v : vars) {
+    all[v.trainer_id_] =
+        scope->FindVar(v.var_name_)->GetMutable<SelectedRows>();
+  }
+  all[collective_context.trainer_id_] = merged_select_rows;
+
+  merge_func(*merged_dev_ctx, all, dst_selected_rows);
+
+  rpc_server->WaitVarBarrier(merged_var_name);
+  rpc_server->ClearVar(merged_var_name);
+
+  // 5. clear mid vars
+  std::vector<std::string> tmp_vars{merged_var_name};
+  for (auto r : vars) {
+    tmp_vars.push_back(r.var_name_);
+  }
+  scope->EraseVars(tmp_vars);
+}
+#endif
+
 void ReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
 
@@ -90,8 +202,36 @@ void ReduceOpHandle::RunImpl() {
     this->RunAndRecordEvent([&] {
       std::vector<const SelectedRows *> in_selected_rows =
           GetInputValues<SelectedRows>(in_var_handles, var_scopes);
-      GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
-                         out_var->GetMutable<framework::SelectedRows>());
+
+      const CollectiveContext &collective_context =
+          *CollectiveContext::GetInstance();
+      VLOG(10) << "GatherSelectedRows CollectiveContext:"
+               << collective_context.String();
+
+      // TODO(gongwb): add cpu support
+      if (collective_context.endpoints_.size() <= 1 ||
+          is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) {
+        GatherLocalSelectedRows(in_selected_rows, in_places, dev_ctxes_,
+                                t_out_p,
+                                out_var->GetMutable<framework::SelectedRows>());
+        return;
+      }
+
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+      if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
+        GatherSelectedRows<platform::CUDADeviceContext, float>(
+            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
+            out_var->GetMutable<framework::SelectedRows>());
+      } else if (framework::IsType<const double>(
+                     in_selected_rows[0]->value().type())) {
+        GatherSelectedRows<platform::CUDADeviceContext, double>(
+            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
+            out_var->GetMutable<framework::SelectedRows>());
+      } else {
+        PADDLE_ENFORCE(false,
+                       "only support double or float when gathering "
+                       "SelectedRows");
+      }
+#endif
     });
   } else {
     std::vector<const LoDTensor *> lod_tensors =
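
A note on naming: `GetRemoteVarName` namespaces each trainer's merged result, and that name is the handle peers use when gathering. A sketch of the convention (the gradient name is hypothetical):

```cpp
// string::Sprintf("%s_merged_tmp@trainer_%d", var_name, trainer_id) yields,
// for a hypothetical gradient "w@GRAD" on trainer 1:
//   "w@GRAD_merged_tmp@trainer_1"
// Trainer 1 registers its locally merged SelectedRows under this name, and
// every other trainer pulls it via CollectiveClient::Gather().
std::string remote_name =
    paddle::string::Sprintf("%s_merged_tmp@trainer_%d", "w@GRAD", 1);
```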
diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h
index 846839029c..5491f00f45 100644
--- a/paddle/fluid/framework/details/reduce_op_handle.h
+++ b/paddle/fluid/framework/details/reduce_op_handle.h
@@ -30,6 +30,32 @@
 namespace paddle {
 namespace framework {
 namespace details {
+struct CollectiveContext {
+  std::vector<std::string> endpoints_;
+  int trainer_id_{0};
+
+  std::string String() const {
+    std::stringstream ss;
+    ss << "endpoints_:";
+    for (auto e : endpoints_) {
+      ss << e << ",";
+    }
+
+    ss << "trainer_id_:" << trainer_id_;
+
+    return ss.str();
+  }
+
+  static CollectiveContext *GetInstance() {
+    std::call_once(init_flag_,
+                   [&]() { context_.reset(new CollectiveContext()); });
+    return context_.get();
+  }
+
+ private:
+  static std::once_flag init_flag_;
+  static std::unique_ptr<CollectiveContext> context_;
+};
 
 struct ReduceOpHandle : public OpHandleBase {
   std::vector<Scope *> local_scopes_;
@@ -64,6 +90,19 @@ struct ReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
+  template <typename DevCtx, typename DataType>
+  void GatherSelectedRows(
+      const std::vector<const SelectedRows *> &src_selected_rows,
+      const std::vector<platform::Place> &in_places,
+      const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
+      VarHandle *out_var_handle, const platform::Place &out_place,
+      SelectedRows *dst_selected_rows);
+#endif
+
+  void Wait(
+      const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes);
+
   template <typename T>
   std::vector<const T *> GetInputValues(
       const std::vector<VarHandle *> &in_var_handles,
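
`CollectiveContext` is a process-wide singleton; `build_strategy.cc` (above) populates it before graph construction. A minimal usage sketch with made-up endpoints:

```cpp
#include "paddle/fluid/framework/details/reduce_op_handle.h"

void SketchInitCollectiveContext() {
  auto* ctx = paddle::framework::details::CollectiveContext::GetInstance();
  ctx->endpoints_ = {"127.0.0.1:6170", "127.0.0.1:6171"};  // made-up endpoints
  ctx->trainer_id_ = 0;
  VLOG(1) << ctx->String();
  // -> "endpoints_:127.0.0.1:6170,127.0.0.1:6171,trainer_id_:0"
}
```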
diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc
index 4e4001e979..3d53511615 100644
--- a/paddle/fluid/framework/executor_thread_worker.cc
+++ b/paddle/fluid/framework/executor_thread_worker.cc
@@ -97,7 +97,7 @@ void ExecutorThreadWorker::SetDevice() {
   static unsigned concurrency_cap = std::thread::hardware_concurrency();
   int thread_id = this->thread_id_;
 
-  if (thread_id < concurrency_cap) {
+  if (static_cast<unsigned>(thread_id) < concurrency_cap) {
     unsigned proc = thread_id;
 
     cpu_set_t mask;
diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
index 449cc78be1..d4a701e0b1 100644
--- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
@@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
 
+  std::string type = is_conv3d() ? "conv3d" : "conv2d";
+
   GraphPatternDetector gpd;
   auto* conv_input =
       gpd.mutable_pattern()
           ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
           ->AsInput()
-          ->assert_is_op_input("conv2d", "Input");
+          ->assert_is_op_input(type, "Input");
   patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
-  conv_bias_pattern(conv_input);
+  conv_bias_pattern(conv_input, is_conv3d());
   int found_conv_bias_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
       desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
       desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
       desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
-      desc.SetType("conv2d");
+      desc.SetType(type);
 
       for (auto& attr : conv->Op()->GetAttrMap()) {
         desc.SetAttr(attr.first, attr.second);
@@ -135,3 +137,5 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
 }  // namespace paddle
 REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
               paddle::framework::ir::ConvBiasFusePass);
+REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv3DBiasFusePass);
diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
index 5775b83b88..f3ad9f1c2b 100644
--- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
@@ -26,11 +26,19 @@ namespace ir {
 class ConvBiasFusePass : public FusePassBase {
  public:
   virtual ~ConvBiasFusePass() {}
+  virtual bool is_conv3d() const { return false; }
 
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
   const std::string name_scope_{"conv_bias_mkldnn_fuse"};
 };
+/*
+ * Fuse Conv3D and Elementwise_add into a Conv3DBias op.
+ */
+class Conv3DBiasFusePass : public ConvBiasFusePass {
+ public:
+  bool is_conv3d() const override { return true; }
+};
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
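
Since `Conv3DBiasFusePass` only overrides `is_conv3d()`, both passes share `ApplyImpl` and differ purely in the op type fed to the pattern. A hypothetical application sketch via the pass registry (assuming the `ir::Pass` lookup API):

```cpp
#include "paddle/fluid/framework/ir/pass.h"

// A sketch, assuming a graph built elsewhere; the pass name matches the
// REGISTER_PASS call above.
std::unique_ptr<paddle::framework::ir::Graph> RunConv3DBiasFuse(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "conv3d_bias_mkldnn_fuse_pass");
  return pass->Apply(std::move(graph));  // fuses conv3d + elementwise_add
}
```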
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 947c934f0f..bb2d953afb 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -177,14 +177,13 @@ class Graph {
     return nullptr;
   }
 
-  const ProgramDesc &program() const { return program_; }
-  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
-      const ProgramDesc &program);
-
   void ResolveHazard(
       const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
 
  private:
+  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
+      const ProgramDesc &program);
+
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 258182b25a..0118019df2 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1030,10 +1030,11 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
 }
 
 PDNode *patterns::ConvBias::operator()(
-    paddle::framework::ir::PDNode *conv_input) {
+    paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
+  std::string type = is_conv3d ? "conv3d" : "conv2d";
   // Create Operators
-  conv_input->assert_is_op_input("conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  conv_input->assert_is_op_input(type, "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
   auto *eltiwse_op =
       pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
   // Create variables
@@ -1041,11 +1042,11 @@ PDNode *patterns::ConvBias::operator()(
   auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
                               ->AsInput()
                               ->assert_is_persistable_var()
-                              ->assert_is_op_input("conv2d", "Filter");
+                              ->assert_is_op_input(type, "Filter");
   // intermediate variable, will be removed in the IR after fuse.
   auto *conv_out_var = pattern->NewNode(conv_out_repr())
                            ->AsIntermediate()
-                           ->assert_is_only_output_of_op("conv2d")
+                           ->assert_is_only_output_of_op(type)
                            ->assert_is_op_input("elementwise_add");
   // Bias stored in elementwise_add
   auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index c12b9503fd..d044802f22 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
 struct ConvBias : public PatternBase {
   ConvBias(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_bias") {}
-  PDNode* operator()(PDNode* conv_input);
+  PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
   // declare operator node's name
   PATTERN_DECL_NODE(conv);
   PATTERN_DECL_NODE(eltwise);
diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc
index 1cf1315d3d..9a9314161b 100644
--- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
+#include <string>
 
 namespace paddle {
 namespace framework {
@@ -21,9 +22,16 @@ namespace ir {
 std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
+  const auto& op_types_list =
+      Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) {
-      n->Op()->SetAttr("use_mkldnn", true);
+      if (op_types_list.empty()) {
+        n->Op()->SetAttr("use_mkldnn", true);
+      } else if (std::find(op_types_list.begin(), op_types_list.end(),
+                           n->Name()) != op_types_list.end()) {
+        n->Op()->SetAttr("use_mkldnn", true);
+      }
     }
   }
   return graph;
@@ -33,5 +41,5 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(mkldnn_placement_pass,
-              paddle::framework::ir::MKLDNNPlacementPass);
+REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
+    .RequirePassAttr("mkldnn_enabled_op_types");
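
Because the pass now requires `mkldnn_enabled_op_types`, callers must set the attribute; an empty set keeps the old enable-everything behavior. A hedged sketch (assuming the `ir::Pass` registry and attribute API):

```cpp
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"

// A sketch: restrict placement to conv2d and pool2d; an empty set enables
// use_mkldnn for every op that declares the attribute.
std::unique_ptr<paddle::framework::ir::Graph> PlaceMKLDNN(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "mkldnn_placement_pass");
  pass->Set("mkldnn_enabled_op_types",
            new std::unordered_set<std::string>({"conv2d", "pool2d"}));
  return pass->Apply(std::move(graph));
}
```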
diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h
index e0bf5ed999..19e5c2c73e 100644
--- a/paddle/fluid/framework/op_info.h
+++ b/paddle/fluid/framework/op_info.h
@@ -31,12 +31,6 @@ class InferShapeBase {
   virtual void operator()(InferShapeContext*) const = 0;
 };
 
-class EstimateFlopsBase {
- public:
-  virtual ~EstimateFlopsBase() = default;
-  virtual size_t operator()(InferShapeContext*) const = 0;
-};
-
 struct OpInfo {
   OpCreator creator_;
   GradOpMakerFN grad_op_maker_;
@@ -44,7 +38,6 @@ struct OpInfo {
   OpAttrChecker* checker_{nullptr};
   InferVarTypeFN infer_var_type_;
   InferShapeFN infer_shape_;
-  EstimateFlopsFN estimate_flops_;
 
   bool HasOpProtoAndChecker() const {
     return proto_ != nullptr && checker_ != nullptr;
diff --git a/paddle/fluid/framework/op_kernel_type.cc b/paddle/fluid/framework/op_kernel_type.cc
new file mode 100644
index 0000000000..6d4801e4a0
--- /dev/null
+++ b/paddle/fluid/framework/op_kernel_type.cc
@@ -0,0 +1,54 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_kernel_type.h"
+
+namespace paddle {
+namespace framework {
+
+size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
+  int cur_loc = 0;
+
+  int place = key.place_.which();
+  cur_loc += OpKernelType::kPlaceBits;
+
+  int data_type = static_cast<int>(key.data_type_) << cur_loc;
+  cur_loc += OpKernelType::kPrimaryDTypeBits;
+
+  int data_layout = static_cast<int>(key.data_layout_) << cur_loc;
+  cur_loc += OpKernelType::kLayoutBits;
+
+  int library_type = static_cast<int>(key.library_type_) << cur_loc;
+  cur_loc += OpKernelType::kLibBits;
+
+  int customized_value = key.customized_type_value_;
+  PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits));
+  customized_value = customized_value << cur_loc;
+  cur_loc += OpKernelType::kCustomizeBits;
+  PADDLE_ENFORCE(cur_loc < 64);
+
+  std::hash<int> hasher;
+  return hasher(place + data_type + data_layout + library_type +
+                customized_value);
+}
+
+bool OpKernelType::operator==(const OpKernelType& o) const {
+  return platform::places_are_same_class(place_, o.place_) &&
+         data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
+         library_type_ == o.library_type_ &&
+         customized_type_value_ == o.customized_type_value_;
+}
+
+}  // namespace framework
+}  // namespace paddle
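
The five fields land in disjoint bit ranges (4 + 8 + 4 + 4 + 4 = 24 bits), so the plain sum handed to `std::hash<int>` behaves like a bitwise OR. A worked example with illustrative values:

```cpp
#include <functional>

size_t SketchHash() {
  // Illustrative field values (not the real enum numerals):
  //   place_.which()         = 1  -> bits  0-3
  //   data_type_             = 5  -> bits  4-11
  //   data_layout_           = 2  -> bits 12-15
  //   library_type_          = 1  -> bits 16-19
  //   customized_type_value_ = 1  -> bits 20-23
  int key = 1 + (5 << 4) + (2 << 12) + (1 << 16) + (1 << 20);
  return std::hash<int>()(key);  // same as OR-ing the shifted fields
}
```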
diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h
index ac03302189..9edc1a3e15 100644
--- a/paddle/fluid/framework/op_kernel_type.h
+++ b/paddle/fluid/framework/op_kernel_type.h
@@ -24,54 +24,55 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-struct OpKernelType {
-  struct Hash {
-    size_t operator()(const OpKernelType& key) const {
-      int place = key.place_.which();
-      int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
-      int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
-      int library_type = static_cast<int>(key.library_type_)
-                         << (LEFT_SHIFT * 3);
-
-      std::hash<int> hasher;
-      return hasher(place + data_type + data_layout + library_type);
-    }
-  };
+class OpKernelType {
+ public:
+  constexpr static int kDefaultCustomizedTypeValue = 0;
 
-  // place, data_type, library_type kinds less than 2^8
-  constexpr static int LEFT_SHIFT = 8;
-
-  proto::VarType::Type data_type_;
-  DataLayout data_layout_;
-  platform::Place place_;
-  LibraryType library_type_;
+  // In total should be smaller than 64.
+  constexpr static int kPlaceBits = 4;
+  constexpr static int kPrimaryDTypeBits = 8;
+  constexpr static int kLayoutBits = 4;
+  constexpr static int kLibBits = 4;
+  constexpr static int kCustomizeBits = 4;
 
   OpKernelType(proto::VarType::Type data_type, platform::Place place,
                DataLayout data_layout = DataLayout::kAnyLayout,
-               LibraryType library_type = LibraryType::kPlain)
+               LibraryType library_type = LibraryType::kPlain,
+               int customized_type_value = kDefaultCustomizedTypeValue)
       : data_type_(data_type),
         data_layout_(data_layout),
         place_(place),
-        library_type_(library_type) {}
+        library_type_(library_type),
+        customized_type_value_(customized_type_value) {}
 
   OpKernelType(proto::VarType::Type data_type,
                const platform::DeviceContext& dev_ctx,
                DataLayout data_layout = DataLayout::kAnyLayout,
-               LibraryType library_type = LibraryType::kPlain)
+               LibraryType library_type = LibraryType::kPlain,
+               int customized_type_value = kDefaultCustomizedTypeValue)
       : data_type_(data_type),
         data_layout_(data_layout),
         place_(dev_ctx.GetPlace()),
-        library_type_(library_type) {}
+        library_type_(library_type),
+        customized_type_value_(customized_type_value) {}
+
+  virtual ~OpKernelType() {}
+
+  struct Hash {
+    size_t operator()(const OpKernelType& key) const;
+  };
 
   size_t hash_key() const { return Hash()(*this); }
 
-  bool operator==(const OpKernelType& o) const {
-    return platform::places_are_same_class(place_, o.place_) &&
-           data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
-           library_type_ == o.library_type_;
-  }
+  bool operator==(const OpKernelType& o) const;
 
   bool operator!=(const OpKernelType& o) const { return !(*this == o); }
+
+  proto::VarType::Type data_type_;
+  DataLayout data_layout_;
+  platform::Place place_;
+  LibraryType library_type_;
+  int customized_type_value_;
 };
 
 inline std::ostream& operator<<(std::ostream& os,
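
The practical effect of `customized_type_value_` is that two otherwise identical kernel keys stop comparing equal, so one op can hold several kernels per (place, dtype, layout, library) combination. A small sketch:

```cpp
// Both keys share place, dtype, layout and library; only the customized
// value differs, so they now index two distinct kernels.
paddle::framework::OpKernelType a(paddle::framework::proto::VarType::FP32,
                                  paddle::platform::CPUPlace());
paddle::framework::OpKernelType b(paddle::framework::proto::VarType::FP32,
                                  paddle::platform::CPUPlace(),
                                  paddle::framework::DataLayout::kAnyLayout,
                                  paddle::framework::LibraryType::kPlain,
                                  /*customized_type_value=*/1);
bool same = (a == b);  // false: customized_type_value_ participates in ==
```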
diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h
index 0e6e74293c..36673e48c2 100644
--- a/paddle/fluid/framework/op_registry.h
+++ b/paddle/fluid/framework/op_registry.h
@@ -35,6 +35,7 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
+
 class Registrar {
  public:
   // In our design, various kinds of classes, e.g., operators and kernels,
@@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
 
 template <typename PlaceType, typename T, typename Func>
 inline void RegisterKernelClass(const char* op_type, const char* library_type,
-                                Func func) {
+                                int customized_type_value, Func func) {
   std::string library(library_type);
   std::string data_layout = "ANYLAYOUT";
   if (library == "MKLDNN") {
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
   }
   OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
                    StringToDataLayout(data_layout),
-                   StringToLibraryType(library_type));
+                   StringToLibraryType(library_type), customized_type_value);
   OperatorWithKernel::AllOpKernels()[op_type][key] = func;
 }
 
@@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
   using KERNEL_TYPE =
       typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
 
-  void operator()(const char* op_type, const char* library_type) const {
+  void operator()(const char* op_type, const char* library_type,
+                  int customized_type_value) const {
     using T = typename KERNEL_TYPE::ELEMENT_TYPE;
     RegisterKernelClass<PlaceType, T>(
-        op_type, library_type, [](const framework::ExecutionContext& ctx) {
+        op_type, library_type, customized_type_value,
+
+        [](const framework::ExecutionContext& ctx) {
           KERNEL_TYPE().Compute(ctx);
         });
     constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
     OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
         func;
-    func(op_type, library_type);
+    func(op_type, library_type, customized_type_value);
   }
 };
 
 template <typename PlaceType, size_t I, typename... KernelType>
 struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
-  void operator()(const char* op_type, const char* library_type) const {}
+  void operator()(const char* op_type, const char* library_type,
+                  int customized_type_value) const {}
 };
 
 // User can register many kernel in one place. The data type could be
@@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
 template <typename PlaceType, typename... KernelType>
 class OpKernelRegistrar : public Registrar {
  public:
-  explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
+  explicit OpKernelRegistrar(const char* op_type, const char* library_type,
+                             int customized_type_value) {
     OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
-    func(op_type, library_type);
+    func(op_type, library_type, customized_type_value);
   }
 };
 
@@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
 template <typename PlaceType, typename... DataTypeAndKernelType>
 class OpKernelRegistrarEx : public Registrar {
  public:
-  explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
+  explicit OpKernelRegistrarEx(const char* op_type, const char* library_type,
+                               int customized_type_value) {
     OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
         func;
-    func(op_type, library_type);
+    func(op_type, library_type, customized_type_value);
   }
 };
 
 template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
 struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
                                   DataTypeAndKernelType...> {
-  void operator()(const char* op_type, const char* library_type) const {}
+  void operator()(const char* op_type, const char* library_type,
+                  int customized_type_value) const {}
 };
 
 template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
@@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
       typename std::tuple_element<I,
                                   std::tuple<DataTypeAndKernelType...>>::type;
 
-  void operator()(const char* op_type, const char* library_type) const {
-    RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
+  void operator()(const char* op_type, const char* library_type,
+                  int customized_type_value) const {
+    RegisterKernelClass<PlaceType, T>(op_type, library_type,
+                                      customized_type_value, Functor());
 
     constexpr auto size =
         std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
     OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
                                DataTypeAndKernelType...>
         func;
-    func(op_type, library_type);
+    func(op_type, library_type, customized_type_value);
   }
 };
 
+// clang-format off
 /**
  * check if MACRO is used in GLOBAL NAMESPACE.
  */
@@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
 /**
  * Macro to register OperatorKernel.
  */
-#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...)        \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
-      __reg_op_kernel_##op_type##_##library_type##__,                      \
-      "REGISTER_OP_KERNEL must be called in global namespace");            \
-  static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__>  \
-      __op_kernel_registrar_##op_type##_##library_type##__(#op_type,       \
-                                                           #library_type); \
-  int TouchOpKernelRegistrar_##op_type##_##library_type() {                \
-    __op_kernel_registrar_##op_type##_##library_type##__.Touch();          \
-    return 0;                                                              \
+#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type,             \
+                                            place_class, customized_name,      \
+                                            customized_type_value, ...)        \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
+      __reg_op_kernel_##op_type##_##library_type##_##customized_name##__,      \
+                                 "REGISTER_OP_KERNEL must be called in "       \
+                                 "global namespace");                          \
+  static ::paddle::framework::OpKernelRegistrar<place_class,                   \
+                                                __VA_ARGS__>                   \
+      __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
+          #op_type, #library_type, customized_type_value);                     \
+  int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
+    __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__   \
+        .Touch();                                                              \
+    return 0;                                                                  \
   }
 
+#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...)   \
+  REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(                                \
+      op_type, library_type, place_class, DEFAULT_TYPE,               \
+      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
+      __VA_ARGS__)
+
 #define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
 
 #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
 
-#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...)      \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
-      __reg_op_kernel_##op_type##_##library_type##__,                       \
-      "REGISTER_OP_KERNEL_EX must be called in global namespace");          \
-  static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
-      __op_kernel_registrar_##op_type##_##library_type##__(#op_type,        \
-                                                           #library_type);  \
-  int TouchOpKernelRegistrar_##op_type##_##library_type() {                 \
-    __op_kernel_registrar_##op_type##_##library_type##__.Touch();           \
-    return 0;                                                               \
+#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class,            \
+                              customized_name, customized_type_value, ...)   \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
+      __reg_op_kernel_##op_type##_##library_type##_##customized_name##__,    \
+      "REGISTER_OP_KERNEL_EX must be called in global namespace");           \
+  static ::paddle::framework::OpKernelRegistrarEx<place_class,               \
+                                                  __VA_ARGS__>               \
+      __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
+          #op_type, #library_type, customized_type_value);                   \
+  int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
+    __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__   \
+        .Touch();                                                              \
+    return 0;                                                                  \
   }
 
 #define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)                 \
-  REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
-                        __VA_ARGS__)
+  REGISTER_OP_KERNEL_EX(                                              \
+      op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE,     \
+      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
+      __VA_ARGS__)
 
-#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
-  REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
+#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...)                  \
+  REGISTER_OP_KERNEL_EX(                                              \
+      op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE,       \
+      ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
+      __VA_ARGS__)
 
 /**
  * Macro to mark what Operator and Kernel
@@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
   extern int TouchOpRegistrar_##op_type();                 \
   UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
 
-#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE)               \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                 \
-      __use_op_kernel_##op_type##_##LIBRARY_TYPE##__,             \
-      "USE_OP_DEVICE_KERNEL must be in global namespace");        \
-  extern int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE(); \
-  UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_ = \
-      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE()
+#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type,                     \
+                                              LIBRARY_TYPE,                \
+                                              customized_name)             \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
+      __use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__,  \
+      "USE_OP_DEVICE_KERNEL must be in global namespace");                 \
+  extern int                                                               \
+      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
+  UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##_ = /* NOLINT */ \
+      TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
+
+#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
+  USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
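+
+// Note: customized_name must match the name given at registration, since both
+// macros paste it into the shared TouchOpKernelRegistrar_* symbol name.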
 
 // TODO(fengjiayi): The following macros
 // seems ugly, do we have better method?
@@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
 #define USE_OP(op_type)   \
   USE_OP_ITSELF(op_type); \
   USE_OP_KERNEL(op_type)
+// clang-format off
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc
index ac9dd8245a..ab14732e4d 100644
--- a/paddle/fluid/framework/operator_test.cc
+++ b/paddle/fluid/framework/operator_test.cc
@@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
     AddInput("input", "input of test op");
     AddOutput("output", "output of test op");
     AddAttr<float>("scale", "scale of cosine op");
+    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
+        .SetDefault(0);
     AddComment("This is test op");
   }
 };
@@ -95,6 +97,8 @@ TEST(OperatorBase, all) {
 namespace paddle {
 namespace framework {
 
+static int special_type_value = 1;
+
 class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
   void Make() {
@@ -103,11 +107,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
     AddAttr<float>("scale", "scale of cosine op")
         .SetDefault(1.0)
         .GreaterThan(0.0);
+    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
+        .SetDefault(0);
     AddComment("This is test op");
   }
 };
 
 static int cpu_kernel_run_num = 0;
+static int cpu_kernel2_run_num = 0;
 
 class OpWithKernelTest : public OperatorWithKernel {
  public:
@@ -117,7 +124,10 @@ class OpWithKernelTest : public OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
   OpKernelType GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
-    return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
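+    // Forward the attribute as the kernel's customized type value, so the
+    // framework can choose between kernels that share data type and layout.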
+    int sub_type = ctx.Attr<int>("kernel_sub_type");
+    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
+                        framework::DataLayout::kAnyLayout,
+                        framework::LibraryType::kPlain, sub_type);
   }
 };
 
@@ -132,6 +142,17 @@ class CPUKernelTest : public OpKernel<float> {
   }
 };
 
+template <typename T1, typename T2>
+class CPUKernel2Test : public OpKernel<float> {
+ public:
+  void Compute(const ExecutionContext& ctx) const {
+    std::cout << ctx.op().DebugString() << std::endl;
+    cpu_kernel2_run_num++;
+    ASSERT_EQ(ctx.op().Input("x"), "IN1");
+    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
+  }
+};
+
 class OpKernelTestMultiInputsProtoAndCheckerMaker
     : public OpProtoAndCheckerMaker {
  public:
@@ -142,6 +163,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
     AddAttr<float>("scale", "scale of cosine op")
         .SetDefault(1.0)
         .GreaterThan(0.0);
+    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
+        .SetDefault(0);
     AddComment("This is test op");
   }
 };
@@ -189,9 +212,15 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
 REGISTER_OP_WITHOUT_GRADIENT(
     op_with_kernel, paddle::framework::OpWithKernelTest,
     paddle::framework::OpKernelTestProtoAndCheckerMaker);
+
 REGISTER_OP_CPU_KERNEL(op_with_kernel,
                        paddle::framework::CPUKernelTest<float, float>);
 
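+// Register a second CPU kernel for the same op under a distinct customized
+// type value; the test below switches between the two via kernel_sub_type.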
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
+    op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME,
+    paddle::framework::special_type_value,
+    paddle::framework::CPUKernel2Test<float, float>);
+
 // test with single input
 TEST(OpKernel, all) {
   paddle::framework::InitDevices(true);
@@ -211,7 +240,19 @@ TEST(OpKernel, all) {
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
   op->Run(scope, cpu_place);
+  // kernel_sub_type = 0, so cpu_kernel runs and cpu_kernel2 does not.
+  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
+  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
+
+  attr = op_desc.mutable_attrs()->Add();
+  attr->set_name("kernel_sub_type");
+  attr->set_type(paddle::framework::proto::AttrType::INT);
+  attr->set_i(1);
+  auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
+  op2->Run(scope, cpu_place);
+  // kernel_sub_type = 1, so cpu_kernel2 runs and cpu_kernel does not.
   ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
+  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
 }
 
 REGISTER_OP_WITHOUT_GRADIENT(
diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h
index 44384082db..e1bdba9b46 100644
--- a/paddle/fluid/framework/selected_rows.h
+++ b/paddle/fluid/framework/selected_rows.h
@@ -32,8 +32,7 @@ namespace framework {
 class SelectedRows {
   /*
    * @brief We can use the SelectedRows structure to reproduce a sparse table.
-   *  A sparse table is a key-value structure that the key is an `int64_t`
-   * number,
+   *  A sparse table is a key-value structure whose key is an `int64_t`,
    *  and the value is a Tensor which the first dimension is 0.
    *  You can use the following interface to operate the sparse table, and you
    * can find
diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h
index cdc5fa6862..2de6233a9e 100644
--- a/paddle/fluid/framework/type_defs.h
+++ b/paddle/fluid/framework/type_defs.h
@@ -54,7 +54,5 @@ using InferVarTypeFN =
 
 using InferShapeFN = std::function<void(InferShapeContext*)>;
 
-using EstimateFlopsFN = std::function<void(InferShapeContext*)>;
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 21203e2d9f..83d411eecf 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -103,6 +103,7 @@ struct Argument {
   // Model specified with program and parameters files.
   DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
   DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
+  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
 
   // The overall graph to work on.
   DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
@@ -115,6 +116,10 @@ struct Argument {
   DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
                       std::vector<std::string>);
 
+  // A set of op types whose MKLDNN kernels should be enabled.
+  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
+                      std::unordered_set<std::string>);
+
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index fce5e1cac9..51bca8039d 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -63,6 +63,11 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
       pass_num++;
     }
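+    // Forward the user-specified whitelist of op types to the MKLDNN
+    // placement pass.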
+    if (pass_name == "mkldnn_placement_pass") {
+      pass->Set("mkldnn_enabled_op_types",
+                new std::unordered_set<std::string>(
+                    argument->mkldnn_enabled_op_types()));
+    }
 
     if (pass_name == "tensorrt_subgraph_pass") {
       PADDLE_ENFORCE(argument->tensorrt_node_teller_valid());
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index c6b7c05f78..4ffe5f575c 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -178,11 +178,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
     output_mapping.push_back(output_name_map[name]);
   }
 
-  *block_desc.Proto()->mutable_vars() =
-      const_cast<framework::ProgramDesc *>(&graph->program())
-          ->Proto()
-          ->blocks(0)
-          .vars();
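+  // Collect var descs from the graph nodes themselves instead of assuming
+  // that every variable lives in block 0 of the original program.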
+  auto *vars = block_desc.Proto()->mutable_vars();
+  for (framework::ir::Node *node : graph->Nodes()) {
+    if (node->IsVar() && node->Var()) {
+      *vars->Add() = *node->Var()->Proto();
+    }
+  }
   PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
                  "the block has no var-desc");
   PADDLE_ENFORCE(!output_mapping.empty());
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
index 740030c3a8..b8a045c18f 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
@@ -46,7 +46,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
              argument->model_params_path_valid()) {
     auto program =
         LoadModel(argument->model_program_path(), argument->model_params_path(),
-                  argument->scope_ptr(), place);
+                  argument->scope_ptr(), place, argument->model_from_memory());
     argument->SetMainProgram(program.release());
   } else {
     PADDLE_THROW(
@@ -68,9 +68,14 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
     const std::string &program_path, const std::string &params_path,
-    framework::Scope *scope, const platform::Place &place) {
+    framework::Scope *scope, const platform::Place &place,
+    bool model_from_memory) {
   framework::Executor exe(place);
-  return Load(&exe, scope, program_path, params_path);
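+  // When loading from memory, program_path and params_path carry the
+  // serialized buffers themselves rather than filesystem paths.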
+  if (!model_from_memory) {
+    return Load(&exe, scope, program_path, params_path);
+  } else {
+    return LoadFromMemory(&exe, scope, program_path, params_path);
+  }
 }
 
 std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; }
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h
index 271e64fce5..adbde0433f 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h
@@ -24,7 +24,7 @@ namespace inference {
 namespace analysis {
 
 /*
- * Load program and parameter to memory from the disk.
+ * Load the program and parameters from the disk or directly from memory.
  */
 class IrGraphBuildPass : public AnalysisPass {
  public:
@@ -38,7 +38,8 @@ class IrGraphBuildPass : public AnalysisPass {
       const platform::Place &place);
   std::unique_ptr<framework::ProgramDesc> LoadModel(
       const std::string &program_path, const std::string &params_path,
-      framework::Scope *scope, const platform::Place &place);
+      framework::Scope *scope, const platform::Place &place,
+      bool model_from_memory);
 
   std::string model_binary_str_;
 };
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index dd75f0d9a6..dcefdd92f5 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -49,10 +49,15 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
   cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
   // fields from this.
   enable_ir_optim = other.enable_ir_optim;
+  // For mkldnn
+  use_mkldnn_ = other.use_mkldnn_;
+  mkldnn_enabled_op_types_ = other.mkldnn_enabled_op_types_;
+
   use_feed_fetch_ops = other.use_feed_fetch_ops;
   use_tensorrt_ = other.use_tensorrt_;
   tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
   tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
+  model_from_memory_ = other.model_from_memory_;
 
   if (use_gpu) {
     pass_builder_.reset(new GpuPassStrategy(
@@ -76,10 +81,16 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
   cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
   // fields from this.
   enable_ir_optim = other.enable_ir_optim;
+  // For mkldnn
+  use_mkldnn_ = other.use_mkldnn_;
+  mkldnn_enabled_op_types_ = other.mkldnn_enabled_op_types_;
+
   use_feed_fetch_ops = other.use_feed_fetch_ops;
   use_tensorrt_ = other.use_tensorrt_;
   tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
   tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
+  model_from_memory_ = other.model_from_memory_;
+
   pass_builder_ = std::move(other.pass_builder_);
 }
 
@@ -102,4 +113,13 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
   pass_builder()->InsertPass(1, "tensorrt_subgraph_pass");
 }
 
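+// The buffers are copied into prog_file / param_file, so the caller may
+// release them once this call returns.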
+void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
+                                             size_t prog_buffer_size,
+                                             const char *param_buffer,
+                                             size_t param_buffer_size) {
+  prog_file = std::string(prog_buffer, prog_buffer + prog_buffer_size);
+  param_file = std::string(param_buffer, param_buffer + param_buffer_size);
+  model_from_memory_ = true;
+}
+
 }  // namespace paddle
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 391330a7c0..be51e7fc1f 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -308,6 +308,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
 
   argument_.SetUseGPU(config_.use_gpu);
   argument_.SetGPUDeviceId(config_.device);
+  argument_.SetModelFromMemory(config_.model_from_memory_);
   // Analyze inference_program
   if (!config_.model_dir.empty()) {
     argument_.SetModelDir(config_.model_dir);
@@ -326,6 +327,10 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
   }
 
+  if (config_.use_mkldnn_) {
+    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
+  }
+
   auto passes = config_.pass_builder()->AllPasses();
   if (!config_.enable_ir_optim) passes.clear();
   argument_.SetIrAnalysisPasses(passes);
@@ -448,20 +453,24 @@ bool AnalysisPredictor::LoadProgramDesc() {
     return false;
   }
 
-  std::string pb_content;
-  // Read binary
-  std::ifstream fin(filename, std::ios::in | std::ios::binary);
-  PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);
-  fin.seekg(0, std::ios::end);
-
-  pb_content.resize(fin.tellg());
-  fin.seekg(0, std::ios::beg);
-  fin.read(&(pb_content.at(0)), pb_content.size());
-  fin.close();
-
   // Create ProgramDesc
   framework::proto::ProgramDesc proto;
-  proto.ParseFromString(pb_content);
+  if (!config_.model_from_memory()) {
+    std::string pb_content;
+    // Read binary
+    std::ifstream fin(filename, std::ios::in | std::ios::binary);
+    PADDLE_ENFORCE(static_cast<bool>(fin.is_open()), "Cannot open file %s",
+                   filename);
+    fin.seekg(0, std::ios::end);
+    pb_content.resize(fin.tellg());
+    fin.seekg(0, std::ios::beg);
+    fin.read(&(pb_content.at(0)), pb_content.size());
+    fin.close();
+
+    proto.ParseFromString(pb_content);
+  } else {
+    proto.ParseFromString(config_.prog_file);
+  }
   inference_program_.reset(new framework::ProgramDesc(proto));
   return true;
 }
@@ -469,6 +478,7 @@ bool AnalysisPredictor::LoadProgramDesc() {
 bool AnalysisPredictor::LoadParameters() {
   PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
                           "The inference program should be loaded first.");
+
   const auto &global_block = inference_program_->MutableBlock(0);
 
   // create a temporary program to load parameters.
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index a09bd1cac2..f05b9832da 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -16,6 +16,7 @@
 #include <cassert>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 // Here we include some header files with relative paths, for that in deploy,
@@ -52,18 +53,26 @@ struct AnalysisConfig : public NativeConfig {
   bool use_tensorrt() const { return use_tensorrt_; }
 
   void EnableMKLDNN();
-  // NOTE this is just for internal development, please not use it.
-  // NOT stable yet.
   bool use_mkldnn() const { return use_mkldnn_; }
+  void SetMKLDNNOp(std::unordered_set<std::string> op_list) {
+    mkldnn_enabled_op_types_ = op_list;
+  }
+
+  // Specify the in-memory buffers that hold the program and the parameters.
+  void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
+                      const char* param_buffer, size_t param_buffer_size);
+  bool model_from_memory() const { return model_from_memory_; }
 
   friend class ::paddle::AnalysisPredictor;
 
  protected:
   bool use_tensorrt_{false};
   bool use_mkldnn_{false};
+  std::unordered_set<std::string> mkldnn_enabled_op_types_;
   int tensorrt_workspace_size_;
   int tensorrt_max_batchsize_;
   std::unique_ptr<PassStrategy> pass_builder_;
+  bool model_from_memory_{false};
 };
 
 // Configurations for Anakin engine.
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 825bee833b..bc5139a7e5 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -98,9 +98,10 @@ class CpuPassStrategy : public PassStrategy {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
     for (auto &pass :
-         std::vector<std::string>({"depthwise_conv_mkldnn_pass",  //
-                                   "conv_bias_mkldnn_fuse_pass",  //
-                                   "conv_relu_mkldnn_fuse_pass",  //
+         std::vector<std::string>({"depthwise_conv_mkldnn_pass",    //
+                                   "conv_bias_mkldnn_fuse_pass",    //
+                                   "conv3d_bias_mkldnn_fuse_pass",  //
+                                   "conv_relu_mkldnn_fuse_pass",    //
                                    "conv_elementwise_add_mkldnn_fuse_pass"})) {
       passes_.push_back(pass);
     }
diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc
index 31f43bfdca..24d15f12f9 100644
--- a/paddle/fluid/inference/io.cc
+++ b/paddle/fluid/inference/io.cc
@@ -69,7 +69,8 @@ bool IsPersistable(const framework::VarDesc* var) {
 void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
                       const framework::ProgramDesc& main_program,
                       const std::string& dirname,
-                      const std::string& param_filename) {
+                      const std::string& param_filename,
+                      bool model_from_memory = false) {
   const framework::BlockDesc& global_block = main_program.Block(0);
 
   framework::ProgramDesc* load_program = new framework::ProgramDesc();
@@ -108,6 +109,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
     op->SetType("load_combine");
     op->SetOutput("Out", paramlist);
     op->SetAttr("file_path", {param_filename});
+    op->SetAttr("model_from_memory", {model_from_memory});
     op->CheckAttrs();
   }
 
@@ -130,16 +132,17 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                  "model version %ld is not supported.",
                  main_program->Version());
 
-  LoadPersistables(executor, scope, *main_program, dirname, "");
+  // model_from_memory is false when the parameters are stored as separate files.
+  LoadPersistables(executor, scope, *main_program, dirname, "",
+                   false /* model_from_memory */);
   return main_program;
 }
 
 std::unique_ptr<framework::ProgramDesc> Load(
     framework::Executor* executor, framework::Scope* scope,
     const std::string& prog_filename, const std::string& param_filename) {
-  std::string model_filename = prog_filename;
   std::string program_desc_str;
-  ReadBinaryFile(model_filename, &program_desc_str);
+  ReadBinaryFile(prog_filename, &program_desc_str);
 
   std::unique_ptr<framework::ProgramDesc> main_program(
       new framework::ProgramDesc(program_desc_str));
@@ -147,7 +150,22 @@ std::unique_ptr<framework::ProgramDesc> Load(
                  "model version %ld is not supported.",
                  main_program->Version());
 
-  LoadPersistables(executor, scope, *main_program, "", param_filename);
+  LoadPersistables(executor, scope, *main_program, "", param_filename,
+                   false /* model_from_memory */);
+  return main_program;
+}
+
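+// Note: param_buffer is forwarded through LoadPersistables' param_filename
+// parameter; with model_from_memory set, load_combine treats it as in-memory
+// data rather than a path.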
+std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
+    framework::Executor* executor, framework::Scope* scope,
+    const std::string& prog_buffer, const std::string& param_buffer) {
+  std::unique_ptr<framework::ProgramDesc> main_program(
+      new framework::ProgramDesc(prog_buffer));
+  PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
+                 "model version %ld is not supported.",
+                 main_program->Version());
+
+  LoadPersistables(executor, scope, *main_program, "", param_buffer,
+                   true /* model_from_memory */);
   return main_program;
 }
 
diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h
index ab492577c1..317ef9d93a 100644
--- a/paddle/fluid/inference/io.h
+++ b/paddle/fluid/inference/io.h
@@ -30,7 +30,8 @@ void Init(const std::vector<std::string> argv);
 void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
                       const framework::ProgramDesc& main_program,
                       const std::string& dirname,
-                      const std::string& param_filename);
+                      const std::string& param_filename,
+                      bool model_from_memory);
 
 std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                              framework::Scope* scope,
@@ -41,6 +42,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                              const std::string& prog_filename,
                                              const std::string& param_filename);
 
+std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
+    framework::Executor* executor, framework::Scope* scope,
+    const std::string& prog_buffer, const std::string& param_buffer);
+
 // Save the variables from a scope to disk.
 void SaveVars(const framework::Scope& scope,
               const std::vector<std::string>& vars, const std::string& dirname,
diff --git a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc
index 453f222f1f..b086c910d3 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc
@@ -90,5 +90,4 @@ TEST(prelu_op, test_scalar) {
 }  // namespace inference
 }  // namespace paddle
 
-// USE_OP(prelu);
-USE_CPU_ONLY_OP(prelu);
+USE_OP(prelu);
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index e822785ad6..95443e8133 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -1,4 +1,4 @@
 nv_library(tensorrt_plugin
            SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
            avg_pool_op_plugin.cu
-           DEPS enforce tensorrt_engine)
+           DEPS enforce tensorrt_engine prelu)
diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
index e8f4254402..3075e87ea6 100644
--- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
@@ -14,92 +14,16 @@
 
 #include <stdio.h>
 #include <cassert>
+#include <vector>
 #include "glog/logging.h"
 #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
+#include "paddle/fluid/operators/math/prelu.h"
 
 namespace paddle {
 namespace inference {
 namespace tensorrt {
 namespace plugin {
 
-static const int CUDA_NUM_THREADS = 1024;
-static const int CUDA_MAX_NUM_BLOCKS = 65535;
-inline static int GET_NUM_BLOCKS(const int N) {
-  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
-}
-
-__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
-                                       float *output, int channel,
-                                       size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const float *in = input + offset;
-  float *out = output + offset;
-  float scale = alpha[blockIdx.x % channel];
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    float x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
-  }
-}
-
-__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
-                                       float *output, size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const float *in = input + offset;
-  const float *scale = alpha + offset;
-  float *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    float x = in[i];
-    out[i] = (x > 0) ? x : scale[i] * x;
-  }
-}
-
-__global__ void PReluScalarKernel(const float *input, const float *alpha,
-                                  float *output, size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const float *in = input + offset;
-  float scale = *alpha;
-  float *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    float x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
-  }
-}
-
-static inline void PReluChannelWise(cudaStream_t stream, const float *input,
-                                    const float *alpha, float *output,
-                                    int batch_size,
-                                    const nvinfer1::Dims &dims) {
-  size_t unroll = batch_size * dims.d[0];
-  size_t spatial_size = dims.d[1] * dims.d[2];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, dims.d[0], spatial_size);
-}
-
-static inline void PReluElementWise(cudaStream_t stream, const float *input,
-                                    const float *alpha, float *output,
-                                    int batch_size,
-                                    const nvinfer1::Dims &dims) {
-  size_t unroll = batch_size * dims.d[0];
-  size_t spatial_size = dims.d[1] * dims.d[2];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
-static inline void PReluScalar(cudaStream_t stream, const float *input,
-                               const float *alpha, float *output,
-                               int batch_size, const nvinfer1::Dims &dims) {
-  size_t unroll = batch_size * dims.d[0];
-  size_t spatial_size = dims.d[1] * dims.d[2];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
 nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                 const nvinfer1::Dims *inputDims,
                                                 int nbInputs) {
@@ -110,19 +34,31 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
   return output_dims;
 }
 
-int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
+int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
                          void **outputs, void *workspace, cudaStream_t stream) {
   // input dims is CHW.
   const auto &input_dims = this->getInputDims(0);
   const float *input = reinterpret_cast<const float *>(inputs[0]);
   const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
   float *output = reinterpret_cast<float **>(outputs)[0];
+
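+  // TensorRT dims exclude the batch dimension, so prepend batch_size to form
+  // the full input shape expected by the PReLU CUDA functors.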
+  std::vector<int> input_shape;
+  input_shape.push_back(batch_size);
+  for (int i = 0; i < input_dims.nbDims; i++) {
+    input_shape.push_back(input_dims.d[i]);
+  }
+
   if (mode_ == "channel") {
-    PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
+    operators::math::PreluChannelWiseDirectCUDAFunctor<float>
+        prelu_channel_wise;
+    prelu_channel_wise(stream, input, alpha, output, input_shape);
   } else if (mode_ == "element") {
-    PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
+    operators::math::PreluElementWiseDirectCUDAFunctor<float>
+        prelu_element_wise;
+    prelu_element_wise(stream, input, alpha, output, input_shape);
   } else {
-    PReluScalar(stream, input, alpha, output, batchSize, input_dims);
+    operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
+    prelu_scalar(stream, input, alpha, output, input_shape);
   }
   return cudaGetLastError() != cudaSuccess;
 }
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index a3a6130db7..227e2ff458 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -188,10 +188,16 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 }
 
 // Easy for profiling independently.
-TEST(Analyzer_dam, profile) {
+void profile(bool use_mkldnn = false) {
   contrib::AnalysisConfig cfg;
   SetConfig(&cfg);
 
+  if (use_mkldnn) {
+    cfg.EnableMKLDNN();
+    std::unordered_set<std::string> op_list = {"conv3d"};
+    cfg.SetMKLDNNOp(op_list);
+  }
+
   std::vector<PaddleTensor> outputs;
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
@@ -209,6 +215,11 @@ TEST(Analyzer_dam, profile) {
   }
 }
 
+TEST(Analyzer_dam, profile) { profile(); }
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_dam, profile_mkldnn) { profile(true /* use_mkldnn */); }
+#endif
+
 // Check the fuse status
 TEST(Analyzer_dam, fuse_statis) {
   contrib::AnalysisConfig cfg;
@@ -222,9 +233,14 @@ TEST(Analyzer_dam, fuse_statis) {
 }
 
 // Compare result of NativeConfig and AnalysisConfig
-TEST(Analyzer_dam, compare) {
-  contrib::AnalysisConfig cfg;
+void compare(bool use_mkldnn = false) {
+  AnalysisConfig cfg;
   SetConfig(&cfg);
+  if (use_mkldnn) {
+    cfg.EnableMKLDNN();
+    std::unordered_set<std::string> op_list = {"conv3d"};
+    cfg.SetMKLDNNOp(op_list);
+  }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
@@ -233,5 +249,10 @@ TEST(Analyzer_dam, compare) {
       reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
+TEST(Analyzer_dam, compare) { compare(); }
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
+#endif
+
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
index 3a5f844de3..66d85420c5 100644
--- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
@@ -93,9 +93,17 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   }
 }
 
-void SetConfig(contrib::AnalysisConfig *cfg) {
-  cfg->prog_file = FLAGS_infer_model + "/__model__";
-  cfg->param_file = FLAGS_infer_model + "/param";
+void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) {
+  if (memory_load) {
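+    // Read the raw model bytes and hand them to the config; SetModelBuffer
+    // copies them, so these local strings may safely go out of scope.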
+    std::string buffer_prog, buffer_param;
+    ReadBinaryFile(FLAGS_infer_model + "/__model__", &buffer_prog);
+    ReadBinaryFile(FLAGS_infer_model + "/param", &buffer_param);
+    cfg->SetModelBuffer(&buffer_prog[0], buffer_prog.size(), &buffer_param[0],
+                        buffer_param.size());
+  } else {
+    cfg->prog_file = FLAGS_infer_model + "/__model__";
+    cfg->param_file = FLAGS_infer_model + "/param";
+  }
   cfg->use_gpu = false;
   cfg->device = 0;
   cfg->specify_input_name = true;
@@ -114,9 +122,9 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 }
 
 // Easy for profiling independently.
-TEST(Analyzer_Chinese_ner, profile) {
+void profile(bool memory_load = false) {
   contrib::AnalysisConfig cfg;
-  SetConfig(&cfg);
+  SetConfig(&cfg, memory_load);
   std::vector<PaddleTensor> outputs;
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
@@ -138,6 +146,12 @@ TEST(Analyzer_Chinese_ner, profile) {
   }
 }
 
+TEST(Analyzer_Chinese_ner, profile) { profile(); }
+
+TEST(Analyzer_Chinese_ner, profile_memory_load) {
+  profile(true /* memory_load */);
+}
+
 // Check the fuse status
 TEST(Analyzer_Chinese_ner, fuse_statis) {
   contrib::AnalysisConfig cfg;
diff --git a/paddle/fluid/inference/tests/api/config_printer.h b/paddle/fluid/inference/tests/api/config_printer.h
index 4231eef722..7046bce303 100644
--- a/paddle/fluid/inference/tests/api/config_printer.h
+++ b/paddle/fluid/inference/tests/api/config_printer.h
@@ -49,8 +49,6 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
   os << GenSpaces(num_spaces) << "device: " << config.device << "\n";
   os << GenSpaces(num_spaces)
      << "fraction_of_gpu_memory: " << config.fraction_of_gpu_memory << "\n";
-  os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
-  os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
   os << GenSpaces(num_spaces)
      << "specify_input_name: " << config.specify_input_name << "\n";
   os << GenSpaces(num_spaces)
@@ -65,6 +63,13 @@ std::ostream &operator<<(std::ostream &os,
   os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n";
   num_spaces++;
   os << *reinterpret_cast<const NativeConfig *>(&config);
+  if (!config.model_from_memory()) {
+    os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
+    os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
+  } else {
+    os << GenSpaces(num_spaces)
+       << "prog_file and param_file: load from memory \n";
+  }
   os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.enable_ir_optim
      << "\n";
   os << GenSpaces(num_spaces)
diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt
index 2104e4ac72..cfb80fe6ec 100644
--- a/paddle/fluid/inference/utils/CMakeLists.txt
+++ b/paddle/fluid/inference/utils/CMakeLists.txt
@@ -1,2 +1,7 @@
 cc_library(benchmark SRCS benchmark.cc DEPS enforce)
 cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
+cc_binary(visualizer SRCS visualizer.cc DEPS analysis
+    paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
+if(WIN32)
+  target_link_libraries(visualizer shlwapi)
+endif(WIN32)
diff --git a/paddle/fluid/inference/utils/visualizer.cc b/paddle/fluid/inference/utils/visualizer.cc
new file mode 100644
index 0000000000..040b6476fb
--- /dev/null
+++ b/paddle/fluid/inference/utils/visualizer.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/utils/visualizer.h"
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <fstream>
+#include <memory>
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
+#include "paddle/fluid/platform/init.h"
+
+DEFINE_string(model_dir, "", "model directory");
+DEFINE_string(model_program_path, "", "model program path");
+DEFINE_string(model_params_path, "", "model params path");
+
+USE_PASS(graph_viz_pass);
+USE_PASS(graph_to_program_pass);
+
+using paddle::inference::analysis::Argument;
+
+namespace paddle {
+namespace inference {
+namespace utils {
+
+void Visualizer::SetArgument(Argument *argument) { argument_ = argument; }
+
+bool Visualizer::Run() {
+  paddle::framework::InitDevices(false);
+  paddle::inference::analysis::Analyzer().Run(argument_);
+
+  return true;
+}
+
+}  // namespace utils
+}  // namespace inference
+}  // namespace paddle
+
+// Generate a dot file describing the structure of the graph.
+// To use this tool, run command: ./visualizer [options...]
+// Options:
+//     --model_dir: the directory of model
+//     --model_program_path: the path of program
+//     --model_params_path: the path of params
+int main(int argc, char *argv[]) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  google::InitGoogleLogging(argv[0]);
+
+  paddle::inference::analysis::Argument argument;
+  argument.SetUseGPU(false);
+  argument.SetUseTensorRT(false);
+
+  if (FLAGS_model_dir.empty()) {
+    if (FLAGS_model_program_path.empty() || FLAGS_model_params_path.empty()) {
+      LOG(ERROR) << "Please set model_dir"
+                    " or model_program_path and model_params_path";
+      return -1;
+    } else {
+      argument.SetModelProgramPath(FLAGS_model_program_path);
+      argument.SetModelParamsPath(FLAGS_model_params_path);
+    }
+  } else {
+    argument.SetModelDir(FLAGS_model_dir);
+  }
+
+  // Only one pass is run; the default output filename is 0_ir_origin.dot.
+  // For more details, see paddle::inference::analysis::IRPassManager.
+  argument.SetIrAnalysisPasses({"graph_viz_pass"});
+
+  std::unique_ptr<paddle::framework::Scope> scope{
+      new paddle::framework::Scope()};
+  argument.SetScopeNotOwned(
+      const_cast<paddle::framework::Scope *>(scope.get()));
+
+  paddle::inference::utils::Visualizer visualizer;
+  visualizer.SetArgument(&argument);
+  visualizer.Run();
+
+  return 0;
+}
diff --git a/paddle/fluid/inference/utils/visualizer.h b/paddle/fluid/inference/utils/visualizer.h
new file mode 100644
index 0000000000..be532f92cf
--- /dev/null
+++ b/paddle/fluid/inference/utils/visualizer.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/inference/analysis/argument.h"
+
+namespace paddle {
+namespace inference {
+namespace utils {
+
+using paddle::inference::analysis::Argument;
+
+class Visualizer final {
+ public:
+  Visualizer() = default;
+  ~Visualizer() = default;
+  Visualizer(const Visualizer &) = delete;
+  Visualizer &operator=(const Visualizer &) = delete;
+
+  void SetArgument(Argument *);
+  bool Run();
+
+ private:
+  Argument *argument_;
+};
+
+}  // namespace utils
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/memory/allocation/legacy_allocator.cc b/paddle/fluid/memory/allocation/legacy_allocator.cc
index 26e2038a53..64aa63ffe9 100644
--- a/paddle/fluid/memory/allocation/legacy_allocator.cc
+++ b/paddle/fluid/memory/allocation/legacy_allocator.cc
@@ -14,11 +14,13 @@
 
 #include "paddle/fluid/memory/allocation/legacy_allocator.h"
 #include <string>
+#include <vector>
 #include "glog/logging.h"
 #include "paddle/fluid/memory/detail/buddy_allocator.h"
 #include "paddle/fluid/memory/detail/system_allocator.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/string/printf.h"
+#include "paddle/fluid/string/split.h"
 
 DEFINE_bool(init_allocated_mem, false,
             "It is a mistake that the values of the memory allocated by "
@@ -86,7 +88,7 @@ struct NaiveAllocator {
 
 template <>
 void *Alloc<platform::CPUPlace>(const platform::CPUPlace &place, size_t size) {
-  VLOG(1) << "Allocate " << size << " bytes on " << platform::Place(place);
+  VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place);
   void *p = GetCPUBuddyAllocator()->Alloc(size);
   if (FLAGS_init_allocated_mem) {
     memset(p, 0xEF, size);
@@ -97,7 +99,7 @@ void *Alloc<platform::CPUPlace>(const platform::CPUPlace &place, size_t size) {
 
 template <>
 void Free<platform::CPUPlace>(const platform::CPUPlace &place, void *p) {
-  VLOG(1) << "Free pointer=" << p << " on " << platform::Place(place);
+  VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place);
   GetCPUBuddyAllocator()->Free(p);
 }
 
@@ -110,19 +112,21 @@ size_t Used<platform::CPUPlace>(const platform::CPUPlace &place) {
 BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
   static std::once_flag init_flag;
   static detail::BuddyAllocator **a_arr = nullptr;
+  static std::vector<int> devices;
 
   std::call_once(init_flag, [gpu_id]() {
-    int gpu_num = platform::GetCUDADeviceCount();
-    PADDLE_ENFORCE(gpu_id < gpu_num, "gpu_id:%d should < gpu_num:%d", gpu_id,
-                   gpu_num);
+    devices = platform::GetSelectedDevices();
+    int gpu_num = devices.size();
 
     a_arr = new BuddyAllocator *[gpu_num];
-    for (int i = 0; i < gpu_num; i++) {
+    for (size_t i = 0; i < devices.size(); ++i) {
+      int dev_id = devices[i];
       a_arr[i] = nullptr;
-      platform::SetDeviceId(i);
-      a_arr[i] = new BuddyAllocator(
-          std::unique_ptr<detail::SystemAllocator>(new detail::GPUAllocator(i)),
-          platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
+      platform::SetDeviceId(dev_id);
+      a_arr[i] = new BuddyAllocator(std::unique_ptr<detail::SystemAllocator>(
+                                        new detail::GPUAllocator(dev_id)),
+                                    platform::GpuMinChunkSize(),
+                                    platform::GpuMaxChunkSize());
 
       VLOG(10) << "\n\nNOTE: each GPU device use "
                << FLAGS_fraction_of_gpu_memory_to_use * 100
@@ -134,7 +138,9 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
   });
 
   platform::SetDeviceId(gpu_id);
-  return a_arr[gpu_id];
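+  // a_arr is indexed by position in the selected-device list, not by raw
+  // device id, so map gpu_id to its position before the lookup.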
+  auto pos = std::distance(devices.begin(),
+                           std::find(devices.begin(), devices.end(), gpu_id));
+  return a_arr[pos];
 }
 #endif
 
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 8c8dc7026e..257bfc0a3f 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -70,7 +70,7 @@ endif()
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
 if (WITH_GPU)
-  set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv)
+  set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
 endif()
 
 # FIXME(typhoonzero): operator deps may not needed.
diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc
index 64649b1a5e..e16b6f78d1 100644
--- a/paddle/fluid/operators/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/activation_mkldnn_op.cc
@@ -100,8 +100,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const T *x_data = x->data<T>();
   T *y_data = y->mutable_data<T>(ctx.GetPlace());
 
-  PADDLE_ENFORCE(x->dims().size() == 2 || x->dims().size() == 4,
-                 "Input dim must be with 2 or 4");
+  PADDLE_ENFORCE(
+      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
+      "Input dim must be with 2, 3 or 4");
 
   std::vector<int> src_tz = framework::vectorize2int(x->dims());
 
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 832245371e..9c5b8604f4 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -76,8 +76,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
   }
 #endif
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
-      ctx.GetPlace(), layout, library);
+      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
+      library);
 }
 
 class ActivationOp : public framework::OperatorWithKernel {
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index a0f8c5c14c..87d549678a 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -41,6 +41,12 @@ static std::unordered_set<std::string> InplaceOpSet = {
     "floor",   "reciprocal", "relu6", "soft_relu", "hard_sigmoid",
 };
 
+/* The following operators can process SelectedRows, because their output
+ * for a zero input is zero as well.
+ */
+static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
+    "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};
+
 static bool IsInplace(std::string op) { return InplaceOpSet.count(op); }
 
 template <typename DeviceContext, typename Functor>
@@ -50,16 +56,38 @@ class ActivationKernel
   using T = typename Functor::ELEMENT_TYPE;
 
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
-                          "Cannot get input tensor X, variable name = %s",
-                          context.op().Input("X"));
-
-    auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
-                            "Cannot get output tensor Out, variable name = %s",
-                            context.op().Output("Out"));
-    Out.mutable_data<T>(context.GetPlace());
+    auto x_var = context.InputVar("X");
+    auto out_var = context.OutputVar("Out");
+    PADDLE_ENFORCE(x_var != nullptr,
+                   "Cannot get input Variable X, variable name = %s",
+                   context.op().Input("X"));
+    PADDLE_ENFORCE(out_var != nullptr,
+                   "Cannot get output Variable Out, variable name = %s",
+                   context.op().Output("Out"));
+
+    framework::Tensor X, *Out;
+
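+    // For whitelisted ops, X / Out may be SelectedRows; operate on their
+    // underlying value tensors so the same functors handle sparse inputs.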
+    if (CanBeUsedBySelectedRows.count(context.op().Type())) {
+      X = detail::Ref(
+          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var),
+          "Cannot get input Tensor X, variable name = %s",
+          context.op().Input("X"));
+      Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
+          out_var);
+    } else {
+      X = detail::Ref(context.Input<framework::Tensor>("X"),
+                      "Cannot get input Tensor X, variable name = %s",
+                      context.op().Input("X"));
+      Out = context.Output<framework::Tensor>("Out");
+    }
+
+    PADDLE_ENFORCE(Out != nullptr,
+                   "Cannot get output tensor Out, variable name = %s",
+                   context.op().Output("Out"));
+
+    Out->mutable_data<T>(context.GetPlace());
     auto x = framework::EigenVector<T>::Flatten(X);
-    auto out = framework::EigenVector<T>::Flatten(Out);
+    auto out = framework::EigenVector<T>::Flatten(*Out);
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
     Functor functor;
@@ -78,14 +106,54 @@ class ActivationGradKernel
  public:
   using T = typename Functor::ELEMENT_TYPE;
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* Out = context.Input<framework::Tensor>("Out");
-    auto* dOut =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto out_var = context.InputVar("Out");
+    auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
+    auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
+    PADDLE_ENFORCE(out_var != nullptr,
+                   "Cannot get input Variable Out, variable name = %s",
+                   context.op().Input("Out"));
+    PADDLE_ENFORCE(out_grad_var != nullptr,
+                   "Cannot get input Variable %s, variable name = %s",
+                   framework::GradVarName("Out"),
+                   context.op().Input(framework::GradVarName("Out")));
+    PADDLE_ENFORCE(x_grad_var != nullptr,
+                   "Cannot get output Variable %s, variable name = %s",
+                   framework::GradVarName("X"),
+                   context.op().Output(framework::GradVarName("X")));
+
+    framework::Tensor Out, dOut, *dX;
+    if (CanBeUsedBySelectedRows.count(context.op().Type())) {
+      Out = detail::Ref(
+          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var),
+          "Cannot get input Tensor Out, variable name = %s",
+          context.op().Input("Out"));
+      dOut =
+          detail::Ref(paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
+                          *out_grad_var),
+                      "Cannot get input Tensor %s, variable name = %s",
+                      framework::GradVarName("Out"),
+                      context.op().Input(framework::GradVarName("Out")));
+      dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
+          x_grad_var);
+    } else {
+      Out = detail::Ref(context.Input<framework::Tensor>("Out"),
+                        "Cannot get input Tensor Out, variable name = %s",
+                        context.op().Input("Out"));
+      dOut = detail::Ref(
+          context.Input<framework::Tensor>(framework::GradVarName("Out")),
+          "Cannot get input Tensor %s, variable name = %s",
+          framework::GradVarName("Out"),
+          context.op().Input(framework::GradVarName("Out")));
+      dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    }
+    PADDLE_ENFORCE(dX != nullptr,
+                   "Cannot get output tensor %s, variable name = %s",
+                   framework::GradVarName("X"),
+                   context.op().Output(framework::GradVarName("X")));
     dX->mutable_data<T>(context.GetPlace());
 
-    auto dout = framework::EigenVector<T>::Flatten(*dOut);
-    auto out = framework::EigenVector<T>::Flatten(*Out);
+    auto dout = framework::EigenVector<T>::Flatten(dOut);
+    auto out = framework::EigenVector<T>::Flatten(Out);
     auto dx = framework::EigenVector<T>::Flatten(*dX);
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
@@ -96,8 +164,19 @@ class ActivationGradKernel
     }
     bool inplace = functor.Inplace();
     if (!inplace) {
-      auto* X = context.Input<framework::Tensor>("X");
-      auto x = framework::EigenVector<T>::Flatten(*X);
+      auto x_var = context.InputVar("X");
+      PADDLE_ENFORCE(x_var != nullptr,
+                     "Cannot get input tensor X, variable name = %s",
+                     context.op().Input("X"));
+      framework::Tensor X;
+      if (CanBeUsedBySelectedRows.count(context.op().Type())) {
+        X = detail::Ref(
+            paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var));
+      } else {
+        X = detail::Ref(context.Input<framework::Tensor>("X"));
+      }
+
+      auto x = framework::EigenVector<T>::Flatten(X);
       functor(*place, x, out, dout, dx);
     } else {
       VLOG(10) << " Inplace activation ";
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 9b943440a8..75fc59125f 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -231,10 +231,10 @@ use lstm_x_t as input and compute as standard LSTM.
 template <typename T>
 inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
   if (bias) {
-    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
-    math::vec_relu<T, platform::jit::avx>(n, y, y);
+    math::vec_add_bias<T, platform::avx>(n, *bias, x, y);
+    math::vec_relu<T, platform::avx>(n, y, y);
   } else {
-    math::vec_relu<T, platform::jit::avx>(n, x, y);
+    math::vec_relu<T, platform::avx>(n, x, y);
   }
 }
 
@@ -245,8 +245,8 @@ inline void vec_softmax(const int n, const T* x, T* y) {
   for (int i = 1; i < n; ++i) {
     scalar = scalar < x[i] ? x[i] : scalar;
   }
-  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
-  math::vec_exp<T>(n, y, y);                                    // exp
+  math::vec_add_bias<T, platform::avx>(n, -scalar, x, y);  // sub
+  math::vec_exp<T>(n, y, y);                               // exp
   // sum
   scalar = T(0);
   for (int i = 0; i < n; ++i) {
@@ -302,13 +302,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
     auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
     auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
+    if (platform::MayIUse(platform::avx)) {
+      math::VecActivations<T, platform::avx> act_functor;
       act_gate = act_functor(act_gate_str);
       act_cell = act_functor(act_cell_str);
       act_cand = act_functor(act_cand_str);
     } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      math::VecActivations<T, platform::isa_any> act_functor;
       act_gate = act_functor(act_gate_str);
       act_cell = act_functor(act_cell_str);
       act_cand = act_functor(act_cand_str);
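The `platform::jit::avx` to `platform::avx` renames above keep the same dispatch scheme: the ISA is a template parameter of the vectorized functor, and a runtime `MayIUse()` check decides which compile-time specialization to use. A minimal standalone model follows; `may_i_use_avx` and `VecRelu` are illustrative stand-ins, not the platform API:

```cpp
#include <functional>
#include <iostream>

// The ISA is a template parameter, so both variants exist at compile time;
// the runtime probe only selects between them.
enum Isa { kIsaAny, kAvx };

template <Isa isa>
struct VecRelu {
  void operator()(int n, const float* x, float* y) const {
    // A real kAvx specialization would use SIMD intrinsics here.
    for (int i = 0; i < n; ++i) y[i] = x[i] > 0.f ? x[i] : 0.f;
  }
};

bool may_i_use_avx() { return false; }  // stand-in for the real CPUID probe

int main() {
  std::function<void(int, const float*, float*)> relu;
  if (may_i_use_avx()) {
    relu = VecRelu<kAvx>();
  } else {
    relu = VecRelu<kIsaAny>();
  }
  float x[3] = {-1.f, 0.f, 2.f}, y[3];
  relu(3, x, y);
  std::cout << y[0] << " " << y[1] << " " << y[2] << "\n";  // 0 0 2
  return 0;
}
```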
diff --git a/paddle/fluid/operators/conv_fusion_op.cu.cc b/paddle/fluid/operators/conv_fusion_op.cu.cc
index 2c09ee7394..3235ad52b9 100644
--- a/paddle/fluid/operators/conv_fusion_op.cu.cc
+++ b/paddle/fluid/operators/conv_fusion_op.cu.cc
@@ -110,11 +110,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 
     auto x_dims = framework::vectorize(input->dims());
     auto f_dims = framework::vectorize(filter->dims());
-    if (activation == "identity") {
-      // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
-      // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
-      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
-    } else if (!exhaustive_search) {
+    if (!exhaustive_search) {
       CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
           handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
           cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
@@ -165,18 +161,42 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                       "workspace_size to be allocated exceeds the limit");
 
-    // ------------------- cudnn conv+bias+act forward --------------------
-    ScalingParamType<T> alpha1 = 1.0f;
-    ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
-    auto cudnn_func = [&](void* cudnn_workspace) {
-      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
-          handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
-          filter_data, cudnn_conv_desc, algo, cudnn_workspace,
-          workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
-          cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
+    if ((activation == "identity") &&
+        (algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) &&
+        (!residual)) {
+      // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
+      // enabled with CUDNN_ACTIVATION_IDENTITY in the cuDNN library.
+      // However, tests show that it is slower in some cases, so fall back
+      // to cudnnConvolutionForward followed by cudnnAddTensor instead.
+      // ------------- cudnn conv forward and bias add ---------------------
+      ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
+      auto cudnn_func = [&](void* cudnn_workspace) {
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+            handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
+            filter_data, cudnn_conv_desc, algo, cudnn_workspace,
+            workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
+      };
+      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+      CUDNN_ENFORCE(platform::dynload::cudnnAddTensor(
+          handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc,
           output_data));
-    };
-    workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+    } else {
+      if (activation == "identity") {
+        algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+      }
+      // ------------------- cudnn conv+bias+act forward --------------------
+      ScalingParamType<T> alpha1 = 1.0f;
+      ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
+      auto cudnn_func = [&](void* cudnn_workspace) {
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
+            handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
+            filter_data, cudnn_conv_desc, algo, cudnn_workspace,
+            workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
+            cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
+            output_data));
+      };
+      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+    }
   }
 };
 #endif
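The restructured branch above boils down to a small decision: keep the fused conv+bias+act call unless the activation is identity, there is no residual input, and the chosen algorithm is not IMPLICIT_PRECOMP_GEMM, in which case plain convolution plus a bias add was faster in the reported tests. A distillation of that control flow only; the enum and helper below are illustrative, not part of the patch:

```cpp
#include <iostream>
#include <string>

// Which cuDNN call sequence the kernel should issue.
enum class Path { kConvThenAddBias, kFusedConvBiasAct };

Path ChoosePath(const std::string& activation, bool has_residual,
                bool algo_is_precomp_gemm) {
  // Identity activation without residual data and without the one algo that
  // supports CUDNN_ACTIVATION_IDENTITY: use conv forward + add tensor.
  if (activation == "identity" && !has_residual && !algo_is_precomp_gemm) {
    return Path::kConvThenAddBias;
  }
  return Path::kFusedConvBiasAct;
}

int main() {
  std::cout << (ChoosePath("identity", false, false) == Path::kConvThenAddBias)
            << "\n";  // 1: falls back to conv + add
  std::cout << (ChoosePath("relu", false, false) == Path::kFusedConvBiasAct)
            << "\n";  // 1: keeps the fused call
  return 0;
}
```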
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 05e268bf6a..154ff2bb20 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -28,6 +28,46 @@ using mkldnn::stream;
 using platform::to_void_cast;
 using platform::GetMKLDNNFormat;
 
+inline void GetWeightsTz(std::vector<int>& weights_tz, int groups,  // NOLINT
+                         bool is_conv3d) {
+  if (groups > 1) {
+    if (is_conv3d) {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int depth = weights_tz[2];
+      int height = weights_tz[3];
+      int width = weights_tz[4];
+      weights_tz.resize(6);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = depth;
+      weights_tz[4] = height;
+      weights_tz[5] = width;
+    } else {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int height = weights_tz[2];
+      int width = weights_tz[3];
+      weights_tz.resize(5);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = height;
+      weights_tz[4] = width;
+    }
+  }
+}
+
+inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
+                                               int groups, bool is_conv3d) {
+  if (is_conv3d) {
+    return (groups == 1) ? format : mkldnn::memory::format::goidhw;
+  } else {
+    return (groups == 1) ? format : mkldnn::memory::format::goihw;
+  }
+}
+
 template <typename T>
 class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -52,10 +92,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                        filter->format() != memory::format::format_undef,
                    "Wrong layout/format set for Filter tensor");
-    PADDLE_ENFORCE(input->dims().size() == 4,
-                   "Input must be with 4 dimensions, i.e. NCHW");
-    PADDLE_ENFORCE(filter->dims().size() == 4,
-                   "Filter must be with 4 dimensions, i.e. OIHW");
+    PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
+                   "Input must have 4 or 5 dimensions, i.e. NCHW or NCDHW");
+    PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
+                   "Filter must have 4 or 5 dimensions, i.e. OIHW or OIDHW");
     if (bias) {
       PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
                          bias->format() != memory::format::format_undef,
@@ -71,9 +111,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
     int groups = ctx.Attr<int>("groups");
 
+    bool is_conv3d = strides.size() == 3U;
     // TODO(tpatejko): add support for dilation
     PADDLE_ENFORCE(
-        dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
+        is_conv3d
+            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
+                  dilations[2] == 1
+            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
         "dilation in convolution is not implemented yet");
 
     const T* input_data = input->data<T>();
@@ -83,18 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
     int g = std::max(groups, 1);
-    if (g > 1) {
-      int o = weights_tz[0];
-      int i = weights_tz[1];
-      int h = weights_tz[2];
-      int w = weights_tz[3];
-      weights_tz.resize(5);
-      weights_tz[0] = g;
-      weights_tz[1] = o / g;
-      weights_tz[2] = i;
-      weights_tz[3] = h;
-      weights_tz[4] = w;
-    }
+    GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
     // Get unique name for storing MKLDNN primitives
@@ -105,11 +138,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     std::vector<primitive> pipeline;
 
+    auto src_format = input->format();
+    mkldnn::memory::format weights_format =
+        GetWeightsFormat(filter->format(), g, is_conv3d);
+
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
     auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(),
-        (g == 1) ? filter->format() : mkldnn::memory::format::goihw);
+        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
 
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
@@ -119,10 +155,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
 
+    if (is_conv3d) {
+      chosen_memory_format =
+          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+    }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
+
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
     std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
                                // Currently used whenever bias is != nullptr.
     auto dst_md = platform::MKLDNNMemDesc(
@@ -263,8 +305,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const mkldnn::engine& engine, const bool fuse_relu,
                        const bool fuse_residual_conn,
                        mkldnn::prop_kind fwd_prop_kind) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
+    memory::dims stride_dims = strides;
+    memory::dims padding_dims = paddings;
 
     auto conv_desc = mkldnn::convolution_forward::desc(
         fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
@@ -288,8 +330,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const mkldnn::engine& engine, const bool fuse_relu,
                        const bool fuse_residual_conn,
                        mkldnn::prop_kind fwd_prop_kind) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
+    memory::dims stride_dims = strides;
+    memory::dims padding_dims = paddings;
 
     auto conv_desc = mkldnn::convolution_forward::desc(
         fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
@@ -349,6 +391,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
 
+    bool is_conv3d = strides.size() == 3U;
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
     const T* output_grad_data = output_grad->data<T>();
@@ -358,8 +401,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
+    int g = std::max(groups, 1);
+    GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
+    auto src_format = input->format();
+    mkldnn::memory::format weights_format =
+        GetWeightsFormat(filter->format(), g, is_conv3d);
+
     // Get a unique name from the "argument" name of the "Output" variable
     // as well as attributes of primitive to be created
     // This name will be used as key when saving info into device context
@@ -372,9 +421,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     // Create user memory descriptors
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
     auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
+        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
     auto user_diff_dst_md = platform::MKLDNNMemDesc(
         {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
 
@@ -386,14 +435,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
 
+    if (is_conv3d) {
+      chosen_memory_format =
+          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+    }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
+
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
     auto diff_weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
     auto diff_dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
@@ -491,8 +546,22 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::ConvMKLDNNOpKernel<float>);
-
-REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::ConvMKLDNNGradOpKernel<float>);
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
+                                    ::paddle::platform::CPUPlace, FP32,
+                                    ops::kConvMKLDNNFP32,
+                                    ops::ConvMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
+                                    ::paddle::platform::CPUPlace, FP32,
+                                    ops::kConvMKLDNNFP32,
+                                    ops::ConvMKLDNNGradOpKernel<float>);
+
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
+                                    ::paddle::platform::CPUPlace, FP32,
+                                    ops::kConvMKLDNNFP32,
+                                    ops::ConvMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
+                                    ::paddle::platform::CPUPlace, FP32,
+                                    ops::kConvMKLDNNFP32,
+                                    ops::ConvMKLDNNGradOpKernel<float>);
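As a worked example of the `GetWeightsTz` helper introduced above, here is a rank-agnostic restatement with hypothetical filter sizes; the 2-D/3-D split in the original exists only because it copies the fields by explicit index:

```cpp
#include <cassert>
#include <vector>

// For grouped convolution, prepend the group count and divide the output
// channels: OIHW {8, 4, 3, 3} with groups=2 becomes GOIHW {2, 4, 4, 3, 3},
// and OIDHW {8, 4, 2, 3, 3} becomes GOIDHW {2, 4, 4, 2, 3, 3}.
void GetWeightsTzSketch(std::vector<int>* tz, int groups) {
  if (groups <= 1) return;
  std::vector<int> out{groups, (*tz)[0] / groups};
  out.insert(out.end(), tz->begin() + 1, tz->end());
  *tz = out;
}

int main() {
  std::vector<int> w2d{8, 4, 3, 3};
  GetWeightsTzSketch(&w2d, 2);
  assert((w2d == std::vector<int>{2, 4, 4, 3, 3}));

  std::vector<int> w3d{8, 4, 2, 3, 3};
  GetWeightsTzSketch(&w3d, 2);
  assert((w3d == std::vector<int>{2, 4, 4, 2, 3, 3}));
  return 0;
}
```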
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 342525be49..d7b8766288 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
+  int customized_type_value =
+      framework::OpKernelType::kDefaultCustomizedTypeValue;
   framework::LibraryType library{framework::LibraryType::kPlain};
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   std::string data_format = ctx.Attr<std::string>("data_format");
@@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
       platform::CanMKLDNNBeUsed(ctx)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
+    customized_type_value = kConvMKLDNNFP32;
   }
 #endif
 
@@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
-                                 library);
+                                 library, customized_type_value);
 }
 
 void Conv2DOpMaker::Make() {
@@ -131,14 +134,14 @@ void Conv2DOpMaker::Make() {
            "The format of output tensor is X (one-dimensional) of size equal"
            "to the number of output channels. Only used with MKL-DNN.")
       .AsDispensable();
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution operator. "
-            "The format of output tensor is also NCHW.");
   AddInput("ResidualData",
            "(Tensor) Tensor with residual data "
            "to which convolution output will be added."
            "Used with fuse_residual_connection fusion.")
       .AsDispensable();
+  AddOutput("Output",
+            "(Tensor) The output tensor of convolution operator. "
+            "The format of output tensor is also NCHW.");
   AddAttr<std::vector<int>>("strides",
                             "(vector<int> default:{1, 1}), the "
                             "strides(h_stride, w_stride) of "
@@ -229,6 +232,10 @@ $$
 }
 
 void Conv3DOpMaker::Make() {
+  AddAttr<bool>("is_test",
+                "(bool, default false) Set to true for inference only, false "
+                "for training. Some layers may run faster when this is true.")
+      .SetDefault(false);
   AddInput(
       "Input",
       "(Tensor) The input tensor of convolution operator. "
@@ -244,6 +251,11 @@ void Conv3DOpMaker::Make() {
            "is the width of the filter."
            "If the groups attribute is greater than 1, C equals the number of "
            "input image channels divided by the groups.");
+  AddInput("ResidualData",
+           "(Tensor) Tensor with residual data "
+           "to which convolution output will be added."
+           "Used with fuse_residual_connection fusion.")
+      .AsDispensable();
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator."
             "The format of output tensor is also NCDHW.");
@@ -277,6 +289,13 @@ void Conv3DOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
+  AddAttr<bool>("fuse_residual_connection",
+                "(bool, default false) Only used in mkldnn kernel. Used "
+                "whenever convolution output is as an input to residual "
+                "connection.")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
@@ -342,6 +361,8 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
+  int customized_type_value =
+      framework::OpKernelType::kDefaultCustomizedTypeValue;
   framework::LibraryType library_{framework::LibraryType::kPlain};
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   std::string data_format = ctx.Attr<std::string>("data_format");
@@ -357,12 +378,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
       platform::CanMKLDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
+    customized_type_value = kConvMKLDNNFP32;
   }
 #endif
 
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout_, library_);
+      layout_, library_, customized_type_value);
 }
 
 }  // namespace operators
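The `customized_type_value` plumbing above, together with `REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE` and the `kConvMKLDNNFP32` constant, lets one operator register several kernels under the same library type. A toy model of the idea; the pair-based registry below is illustrative, and the real `OpKernelType` key also carries data type, place, and layout:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <utility>

// The kernel key gains an extra integer so several MKLDNN kernels of one op
// (e.g. FP32 vs INT8 conv2d) can coexist in the registry.
constexpr int kDefaultCustomizedTypeValue = 0;
constexpr int kConvMKLDNNFP32 = 1;
constexpr int kConvMKLDNNINT8 = 2;

using KernelKey = std::pair<std::string, int>;  // (library, custom value)

int main() {
  std::map<KernelKey, std::string> registry;
  registry[{"MKLDNN", kConvMKLDNNFP32}] = "ConvMKLDNNOpKernel<float>";
  registry[{"MKLDNN", kConvMKLDNNINT8}] = "ConvMKLDNNOpKernel<int8_t>";
  registry[{"PLAIN", kDefaultCustomizedTypeValue}] = "GemmConvKernel<float>";

  // GetExpectedKernelType picks the custom value; lookup resolves the kernel.
  std::cout << registry[{"MKLDNN", kConvMKLDNNFP32}] << "\n";
  return 0;
}
```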
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index e69814001e..249f308c13 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -27,6 +27,8 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+constexpr int kConvMKLDNNFP32 = 1;
+constexpr int kConvMKLDNNINT8 = 2;
 
 // Base convolution operator definations for other conv
 // like operators to reuse the implementation.
diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
index e01070c7b8..dd64cc327f 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
@@ -177,11 +177,19 @@ struct CudnnRNNCache {
         seed_));
 
     CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
+
+#if CUDNN_VERSION >= 6000
     CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
         handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
         CUDNN_LINEAR_INPUT,
         is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
         CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
+#else
+    CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor(
+        rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
+        is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
+        CUDNN_DATA_FLOAT));
+#endif
 
     CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
     CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
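The guard above is the standard compile-time API-selection pattern: call `cudnnSetRNNDescriptor_v6` when the headers are new enough, otherwise keep the legacy `cudnnSetRNNDescriptor`. A generic sketch of the pattern with stand-in names; `MYLIB_VERSION` and both functions are placeholders, not cuDNN symbols:

```cpp
#include <iostream>

// Pretend the library headers advertise their version like CUDNN_VERSION.
#define MYLIB_VERSION 6000

void SetDescriptorV6() { std::cout << "v6 descriptor API\n"; }
void SetDescriptorLegacy() { std::cout << "legacy descriptor API\n"; }

void SetDescriptor() {
#if MYLIB_VERSION >= 6000
  SetDescriptorV6();      // cudnnSetRNNDescriptor_v6 analogue
#else
  SetDescriptorLegacy();  // cudnnSetRNNDescriptor analogue
#endif
}

int main() {
  SetDescriptor();
  return 0;
}
```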
diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt
index 36979de68f..101dbe9c89 100644
--- a/paddle/fluid/operators/distributed/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed/CMakeLists.txt
@@ -13,16 +13,26 @@ set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor
 
 if(WITH_GRPC)
   grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
-        request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
+        request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
       PROTO send_recv.proto 
-      DEPS lod_tensor selected_rows memory)
+      DEPS lod_tensor selected_rows_functor memory)
 
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+
   cc_test(grpc_serde_test SRCS grpc_serde_test.cc 
     DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
+
   cc_test(rpc_server_test SRCS rpc_server_test.cc
     DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor  proto_desc lookup_sparse_table_op SERIAL)
+
   cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
+
+  if(WITH_GPU)
+    cc_test(collective_server_test SRCS collective_server_test.cc
+        DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
+        selected_rows_functor scope math_function SERIAL)
+  endif()
+
   cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
 else()
   set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
diff --git a/paddle/fluid/operators/distributed/collective_client.cc b/paddle/fluid/operators/distributed/collective_client.cc
new file mode 100644
index 0000000000..6d3f534311
--- /dev/null
+++ b/paddle/fluid/operators/distributed/collective_client.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <condition_variable>  // NOLINT
+#include <string>
+#include "gflags/gflags.h"
+
+#include "paddle/fluid/operators/distributed/collective_client.h"
+
+DECLARE_int32(rpc_deadline);
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+std::once_flag CollectiveClient::init_flag_;
+std::unique_ptr<CollectiveClient> CollectiveClient::client_(nullptr);
+
+bool CollectiveClient::Gather(const std::vector<RemoteVar>& remote_vars,
+                              std::vector<const framework::SelectedRows*>* dst,
+                              const platform::DeviceContext& ctx,
+                              framework::Scope* scope, int64_t time_out) {
+  for (auto r : remote_vars) {
+    VLOG(50) << "begin gather from ep:" << r.String();
+    scope->Var(r.var_name_)->GetMutable<framework::SelectedRows>();
+    rpc_client_->AsyncGetMonomerVariable(r.ep_, ctx, *scope, r.var_name_,
+                                         time_out);
+  }
+
+  rpc_client_->Wait();
+
+  for (auto r : remote_vars) {
+    auto select_rows =
+        scope->FindVar(r.var_name_)->GetMutable<framework::SelectedRows>();
+    dst->push_back(select_rows);
+
+    VLOG(4) << "gather from ep:" << r.String()
+            << ", select_rows:" << GetSelectedRowsInfo(*select_rows);
+
+    rpc_client_->AsyncGetMonomerBarrier(r.ep_, r.var_name_);
+  }
+
+  rpc_client_->Wait();
+  return true;
+}
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
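A hypothetical call site for `CollectiveClient::Gather`, sketched under the assumption of two trainers serving the same variable; the endpoints and the variable name are made up, while the types and signatures come from the header introduced in this patch:

```cpp
#include <vector>

#include "paddle/fluid/operators/distributed/collective_client.h"

// Gather the SelectedRows named "grad_var" from two trainers; Gather blocks
// until every fetch and its barrier have completed, and dst keeps rank order.
void GatherExample(const paddle::platform::DeviceContext& ctx,
                   paddle::framework::Scope* scope) {
  namespace dist = paddle::operators::distributed;

  dist::RemoteVar r0, r1;
  r0.ep_ = "127.0.0.1:7164";
  r0.var_name_ = "grad_var";
  r0.trainer_id_ = 0;
  r1.ep_ = "127.0.0.1:7165";
  r1.var_name_ = "grad_var";
  r1.trainer_id_ = 1;

  std::vector<const paddle::framework::SelectedRows*> dst;
  dist::CollectiveClient::GetInstance()->Gather({r0, r1}, &dst, ctx, scope);
  // dst[0] and dst[1] now point at the rows gathered from trainers 0 and 1.
}
```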
diff --git a/paddle/fluid/operators/distributed/collective_client.h b/paddle/fluid/operators/distributed/collective_client.h
new file mode 100644
index 0000000000..53b03c531a
--- /dev/null
+++ b/paddle/fluid/operators/distributed/collective_client.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <condition_variable>  // NOLINT
+#include <string>
+#include <vector>
+#include "gflags/gflags.h"
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/operators/detail/macros.h"
+#include "paddle/fluid/operators/distributed/request_handler.h"
+
+DECLARE_int32(rpc_deadline);
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+inline std::string GetSelectedRowsInfo(const framework::SelectedRows& slr) {
+  std::stringstream ss;
+  ss << ", height:" << slr.height() << ", rows:[";
+  for (unsigned int i = 0; i < slr.rows().size(); i++) {
+    if (i != slr.rows().size() - 1) {
+      ss << slr.rows()[i] << ",";
+    } else {
+      ss << slr.rows()[i];
+    }
+  }
+  ss << "], dims:" << slr.value().dims();
+  return ss.str();
+}
+
+struct RemoteVar {
+  std::string ep_;
+  std::string var_name_;
+  int trainer_id_{0};
+
+  std::string String() {
+    std::stringstream ss;
+    ss << "ep:" << ep_ << ", var_name:" << var_name_
+       << ", trainer_id:" << trainer_id_;
+
+    return ss.str();
+  }
+};
+
+class CollectiveClient {
+ public:
+  CollectiveClient() {
+    rpc_client_.reset(new RPCCLIENT_T());
+    rpc_client_->InitImpl();
+  }
+  virtual ~CollectiveClient() {}
+
+  // Note: this function preserves the rank order of remote_vars.
+  bool Gather(const std::vector<RemoteVar>& remote_vars,
+              std::vector<const framework::SelectedRows*>* dst,
+              const platform::DeviceContext& ctx, framework::Scope* scope,
+              int64_t time_out = FLAGS_rpc_deadline);
+
+  static CollectiveClient* GetInstance() {
+    std::call_once(init_flag_, [&]() {
+      if (client_.get() == nullptr) {
+        client_.reset(new CollectiveClient());
+      }
+    });
+    return client_.get();
+  }
+
+ private:
+  std::unique_ptr<RPCClient> rpc_client_;
+
+  static std::once_flag init_flag_;
+  static std::unique_ptr<CollectiveClient> client_;
+};
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/distributed/collective_server.cc b/paddle/fluid/operators/distributed/collective_server.cc
new file mode 100644
index 0000000000..c95652400c
--- /dev/null
+++ b/paddle/fluid/operators/distributed/collective_server.cc
@@ -0,0 +1,74 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <stdio.h>  // for removing the port file
+#include <chrono>   // NOLINT
+#include <csignal>
+#include <cstdlib>
+#include <fstream>
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "paddle/fluid/operators/distributed/collective_server.h"
+
+DEFINE_int32(collective_get_thread_num, 5, "number of threads for rpc get");
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+std::once_flag CollectiveServer::init_flag_;
+std::shared_ptr<CollectiveServer> CollectiveServer::collective_server_(nullptr);
+
+CollectiveServer::CollectiveServer(const std::string& end_point, int fan_in) {
+  VLOG(1) << "Create colllective server:" << end_point << ", fan_in:" << fan_in;
+  rpc_server_.reset(new RPCSERVER_T(end_point, fan_in));
+}
+
+void CollectiveServer::Stop() {
+  rpc_server_->ShutDown();
+  server_thread_->join();
+  loop_thread_->join();
+}
+
+void CollectiveServer::StartServer() {
+  get_monomer_handler_.reset(new GetMonomerHandler());
+  get_monomer_handler_->SetRPCServer(rpc_server_.get());
+
+  get_barrier_handler_.reset(new GetMonomerBarrierHandler());
+  get_barrier_handler_->SetRPCServer(rpc_server_.get());
+
+  rpc_server_->RegisterRPC(distributed::kRequestGetMonomerVariable,
+                           get_monomer_handler_.get(),
+                           FLAGS_collective_get_thread_num);
+  rpc_server_->RegisterRPC(distributed::kRequestGetMonomerBarrier,
+                           get_barrier_handler_.get(), 1);
+
+  server_thread_.reset(new std::thread([&]() { rpc_server_->StartServer(); }));
+  rpc_server_->WaitServerReady();
+
+  loop_thread_.reset(new std::thread([&]() {
+    while (true) {
+      if (rpc_server_->IsExit()) {
+        LOG(WARNING) << "get exit!rpc_processor break!";
+        break;
+      }
+      std::this_thread::sleep_for(std::chrono::seconds(1));
+    }
+    VLOG(1) << "CollectiveServer loop_thread end";
+  }));
+}
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
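A plausible bring-up sequence for this server, assembled from the calls the patch itself introduces (`GetInstance`, `RegisterVar`, `WaitVarBarrier`, `ClearRegisteredVars`); the endpoint, fan-in, and variable name are illustrative:

```cpp
#include "paddle/fluid/operators/distributed/collective_server.h"

// Start (or reuse) the singleton server, expose one variable to fan_in peers,
// and wait until every peer has pulled it before clearing the registration.
void ServeOnce(paddle::framework::Scope* scope,
               paddle::platform::DeviceContext* dev_ctx) {
  namespace dist = paddle::operators::distributed;

  auto* server =
      dist::CollectiveServer::GetInstance("127.0.0.1:7164", /*fan_in=*/2);
  auto rpc_server = server->GetRPCServer();

  rpc_server->RegisterVar("grad_var", dist::kRequestGetMonomerVariable, scope,
                          dev_ctx);
  rpc_server->WaitVarBarrier("grad_var");  // all fan_in clients have fetched
  rpc_server->ClearRegisteredVars();       // ready for the next round
}
```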
diff --git a/paddle/fluid/operators/distributed/collective_server.h b/paddle/fluid/operators/distributed/collective_server.h
new file mode 100644
index 0000000000..a23dc18b4d
--- /dev/null
+++ b/paddle/fluid/operators/distributed/collective_server.h
@@ -0,0 +1,110 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <map>
+#include <set>
+#include <string>
+#include <thread>  // NOLINT
+#include <utility>
+#include <vector>
+
+#include "gflags/gflags.h"
+
+#include "paddle/fluid/operators/detail/macros.h"
+#include "paddle/fluid/operators/distributed/request_handler.h"
+#include "paddle/fluid/operators/distributed/request_handler_impl.h"
+#include "paddle/fluid/operators/distributed/rpc_server.h"
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+class CollectiveServer;
+
+class GetMonomerHandler final : public RequestHandler {
+ public:
+  GetMonomerHandler() : RequestHandler(true) {}
+  virtual ~GetMonomerHandler() {}
+  bool Handle(const std::string& var_name, framework::Scope* scope,
+              framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id, const std::string& out_var_name = "",
+              const std::string& table_name = "") override {
+    VLOG(50) << "GetMonomerHandler recv " << var_name;
+
+    *outvar = scope->FindVar(var_name);
+    PADDLE_ENFORCE(*outvar != nullptr, "%s not found", var_name);
+
+    return true;
+  }
+};
+
+class GetMonomerBarrierHandler final : public RequestHandler {
+ public:
+  GetMonomerBarrierHandler() : RequestHandler(true) {}
+  virtual ~GetMonomerBarrierHandler() {}
+  bool Handle(const std::string& var_name, framework::Scope* scope,
+              framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id, const std::string& out_var_name = "",
+              const std::string& table_name = "") override {
+    VLOG(50) << "GetMonomerHandler recv " << var_name;
+
+    rpc_server_->IncreaseVarBarrier(var_name);
+
+    return true;
+  }
+};
+
+class CollectiveServer final {
+ public:
+  explicit CollectiveServer(const std::string& end_point, int fan_in);
+
+  virtual ~CollectiveServer() {}
+
+  void StartServer();
+
+  static CollectiveServer* GetInstance(const std::string& end_point,
+                                       int fan_in) {
+    std::call_once(init_flag_, [&]() {
+      if (collective_server_.get() == nullptr) {
+        collective_server_.reset(new CollectiveServer(end_point, fan_in));
+        collective_server_->StartServer();
+      }
+    });
+
+    return collective_server_.get();
+  }
+
+  std::shared_ptr<RPCServer> GetRPCServer() { return rpc_server_; }
+
+  void Stop();
+
+ private:
+  std::unique_ptr<GetMonomerHandler> get_monomer_handler_;
+  std::unique_ptr<GetMonomerBarrierHandler> get_barrier_handler_;
+
+  std::shared_ptr<distributed::RPCServer> rpc_server_;
+  std::shared_ptr<std::thread> server_thread_;
+  std::shared_ptr<std::thread> loop_thread_;
+
+  bool ready_{false};
+
+  static std::once_flag init_flag_;
+  static std::shared_ptr<CollectiveServer> collective_server_;
+};
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/distributed/collective_server_test.cc b/paddle/fluid/operators/distributed/collective_server_test.cc
new file mode 100644
index 0000000000..0a9c69e393
--- /dev/null
+++ b/paddle/fluid/operators/distributed/collective_server_test.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <unistd.h>
+#include <string>
+#include <thread>  // NOLINT
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+#include "paddle/fluid/operators/detail/macros.h"
+#include "paddle/fluid/operators/distributed/collective_client.h"
+#include "paddle/fluid/operators/distributed/collective_server.h"
+#include "paddle/fluid/operators/distributed/request_handler_impl.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+namespace distributed = paddle::operators::distributed;
+
+std::unique_ptr<distributed::CollectiveServer> StartServer(
+    const std::string& ep, int fan_in, framework::Scope* scope,
+    platform::DeviceContext* dev_ctx) {
+  distributed::CollectiveServer* server =
+      distributed::CollectiveServer::GetInstance(ep, fan_in);
+
+  auto rpc_server = server->GetRPCServer();
+  rpc_server->RegisterVar("var1", distributed::kRequestGetMonomerVariable,
+                          scope, dev_ctx);
+
+  std::cout << "StartServer return" << std::endl;
+  return std::unique_ptr<distributed::CollectiveServer>(server);
+}
+
+std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto& ctx = *pool.Get(place);
+
+  framework::Scope* scope = new framework::Scope();
+  framework::Variable* var = scope->Var("var1");
+  auto* slr = var->GetMutable<framework::SelectedRows>();
+  slr->set_height(1000);
+
+  auto* tensor = slr->mutable_value();
+  auto* rows = slr->mutable_rows();
+
+  tensor->Resize(framework::make_ddim({3, 5}));
+  tensor->mutable_data<float>(place);
+
+  paddle::operators::math::set_constant(ctx, tensor, 32.7);
+  for (int i = 0; i < 3; ++i) rows->push_back(i);
+
+  std::cout << "src:" << distributed::GetSelectedRowsInfo(*slr);
+
+  return std::unique_ptr<framework::Scope>(scope);
+}
+
+void Gather(const std::vector<distributed::RemoteVar>& vars,
+            platform::DeviceContext* dev_ctx) {
+  distributed::CollectiveClient* client =
+      distributed::CollectiveClient::GetInstance();
+
+  framework::Scope* scope = new framework::Scope();
+  framework::Variable* var = scope->Var("var1");
+  var->GetMutable<framework::SelectedRows>();
+
+  std::vector<const framework::SelectedRows*> dst;
+  client->Gather(vars, &dst, *dev_ctx, scope);
+  std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]);
+}
+
+TEST(COLLECTIVE_SERVER, GPU) {
+  platform::CUDAPlace place;
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto& ctx = *pool.Get(place);
+
+  std::string ep = "127.0.0.1:7164";
+  auto scope = GenerateVars(place);
+
+  auto* v1 = scope->FindVar("var1");
+  std::cout << "var1:" << v1 << std::endl;
+
+  auto server = StartServer(ep, 2, scope.get(), &ctx);
+  auto rpc_server = server->GetRPCServer();
+
+  distributed::RemoteVar var;
+  var.ep_ = ep;
+  var.var_name_ = "var1";
+  var.trainer_id_ = 0;
+
+  std::vector<distributed::RemoteVar> vars{var};
+  Gather(vars, &ctx);
+  Gather(vars, &ctx);
+
+  std::cout << "begin WaitVarBarrier" << std::endl;
+  rpc_server->WaitVarBarrier("var1");
+  rpc_server->ClearRegisteredVars();
+  server->Stop();
+
+  // Intentionally leak the scope and the server: the singleton
+  // CollectiveServer may still be referenced at process exit, and releasing
+  // the unique_ptrs here avoids destruction-order problems in the test.
+  scope.release();
+  server.release();
+}
diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc
index d7f3ea86af..857214aa21 100644
--- a/paddle/fluid/operators/distributed/grpc_client.cc
+++ b/paddle/fluid/operators/distributed/grpc_client.cc
@@ -28,11 +28,11 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-void GRPCClient::InitImpl() { InitEventLoop(); }
-
-void GRPCClient::InitEventLoop() {
+void GRPCClient::InitImpl() {
   // start the client process thread
   // TODO(wuyi): can make this in a threadpool
+  PADDLE_ENFORCE(client_thread_ == nullptr,
+                 "the Proceed thread must not be initialized twice");
   client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
 }
 
@@ -106,6 +106,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
 
 void ProcGetResponse(const VarHandle& var_h,
                      const ::grpc::ByteBuffer& ret_msg) {
+  VLOG(100) << "ProcGetResponse";
   framework::Variable* outvar = nullptr;
   // get response's trainer_id is not used
   int trainer_id;
@@ -126,6 +127,24 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
+  return AsyncGetVarImpl(ep, ctx, scope, var_name,
+                         "/sendrecv.SendRecvService/GetVariable", time_out);
+}
+
+VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& var_name,
+    int64_t time_out) {
+  return AsyncGetVarImpl(ep, ctx, scope, var_name,
+                         "/sendrecv.SendRecvService/GetMonomerVariable",
+                         time_out);
+}
+
+VarHandlePtr GRPCClient::AsyncGetVarImpl(const std::string& ep,
+                                         const platform::DeviceContext& ctx,
+                                         const framework::Scope& scope,
+                                         const std::string& var_name,
+                                         const std::string& rpc_path,
+                                         int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
@@ -136,7 +155,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
+  framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -151,8 +170,8 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
 
     platform::RecordRPCEvent record_event(method, p_ctx);
 
-    auto call = s->stub_g_.PrepareUnaryCall(
-        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
+    auto call =
+        s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
     call->StartCall();
     call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
 
@@ -268,6 +287,34 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
   return h;
 }
 
+VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
+                                                const std::string& var_name,
+                                                int64_t time_out) {
+  const auto ch = GetChannel(ep);
+  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
+  const std::string method = "SendMonomerFetchBarrierRPC";
+  VarHandlePtr h(
+      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
+  s->Prepare(h, time_out);
+
+  VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
+
+  sendrecv::VariableMessage req;
+  req.set_varname(var_name);
+
+  platform::RecordRPCEvent record_event(method, nullptr);
+
+  auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
+  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
+  req_count_++;
+
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    h->Wait();
+  }
+
+  return h;
+}
+
 VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                            int64_t time_out) {
   const auto ch = GetChannel(ep);
diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h
index a31a465645..01bf46cc31 100644
--- a/paddle/fluid/operators/distributed/grpc_client.h
+++ b/paddle/fluid/operators/distributed/grpc_client.h
@@ -189,6 +189,11 @@ class GRPCClient : public RPCClient {
                            const std::string& var_name,
                            int64_t time_out = FLAGS_rpc_deadline) override;
 
+  VarHandlePtr AsyncGetMonomerVariable(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+
   VarHandlePtr AsyncPrefetchVar(const std::string& ep,
                                 const platform::DeviceContext& ctx,
                                 const framework::Scope& scope,
@@ -200,8 +205,12 @@ class GRPCClient : public RPCClient {
   VarHandlePtr AsyncSendBatchBarrier(
       const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
 
-  VarHandlePtr AsyncSendFetchBarrier(
-      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendFetchBarrier(const std::string& ep,
+                                     int64_t time_out) override;
+
+  VarHandlePtr AsyncGetMonomerBarrier(
+      const std::string& ep, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dir,
@@ -214,21 +223,22 @@ class GRPCClient : public RPCClient {
 
   void SendComplete() override;
 
- protected:
   void InitImpl() override;
 
  private:
-  // InitEventLoop should only be called by Init()
-  void InitEventLoop();
-
   void Proceed();
 
   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  VarHandlePtr AsyncGetVarImpl(const std::string& ep,
+                               const platform::DeviceContext& ctx,
+                               const framework::Scope& scope,
+                               const std::string& var_name,
+                               const std::string& rpc_path, int64_t time_out);
 
  private:
   grpc::CompletionQueue cq_;
   std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
-  std::unique_ptr<std::thread> client_thread_;
+  std::unique_ptr<std::thread> client_thread_{nullptr};
 
   // mutex for Wait client sync
   std::mutex sync_mutex_;
diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc
index d9200c98b2..c3974138f4 100644
--- a/paddle/fluid/operators/distributed/grpc_server.cc
+++ b/paddle/fluid/operators/distributed/grpc_server.cc
@@ -158,6 +158,98 @@ class RequestGet final : public RequestBase {
   ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
 };
 
+class RequestGetMonomerVariable final : public RequestBase {
+ public:
+  explicit RequestGetMonomerVariable(GrpcService::AsyncService* service,
+                                     ::grpc::ServerCompletionQueue* cq,
+                                     RequestHandler* request_handler,
+                                     int req_id, RPCServer* rpc_server)
+      : RequestBase(service, cq, request_handler, req_id),
+        responder_(&ctx_),
+        rpc_server_(rpc_server) {
+    auto method_id =
+        static_cast<int>(distributed::GrpcMethod::kGetMonomerVariable);
+    service_->RequestAsyncUnary(
+        method_id, &ctx_, &request_, &responder_, cq_, cq_,
+        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
+  }
+
+  virtual ~RequestGetMonomerVariable() {}
+
+  std::string GetReqName() override { return request_.varname(); }
+
+  void Process() override {
+    // proc request.
+    std::string varname = request_.varname();
+
+    rpc_server_->WaitVarCond(varname);
+    MonomerHandle h = rpc_server_->GetMonomer(varname);
+
+    auto scope = h.scope_;
+    auto invar = scope->FindVar(varname);
+    framework::Variable* outvar = nullptr;
+
+    request_handler_->Handle(varname, scope, invar, &outvar,
+                             request_.trainer_id());
+
+    if (outvar) {
+      SerializeToByteBuffer(varname, outvar, *h.dev_ctx_, &reply_);
+    }
+    Finish(reply_, &responder_);
+  }
+
+ protected:
+  sendrecv::VariableMessage request_;
+  ::grpc::ByteBuffer reply_;
+  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
+  RPCServer* rpc_server_{nullptr};
+};
+
+class RequestGetMonomerBarrier final : public RequestBase {
+ public:
+  explicit RequestGetMonomerBarrier(GrpcService::AsyncService* service,
+                                    ::grpc::ServerCompletionQueue* cq,
+                                    RequestHandler* request_handler, int req_id,
+                                    RPCServer* rpc_server)
+      : RequestBase(service, cq, request_handler, req_id),
+        responder_(&ctx_),
+        rpc_server_(rpc_server) {
+    auto method_id =
+        static_cast<int>(distributed::GrpcMethod::kGetMonomerBarrier);
+    service_->RequestAsyncUnary(
+        method_id, &ctx_, &request_, &responder_, cq_, cq_,
+        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
+  }
+
+  virtual ~RequestGetMonomerBarrier() {}
+
+  std::string GetReqName() override { return request_.varname(); }
+
+  void Process() override {
+    // proc request.
+    std::string varname = request_.varname();
+    VLOG(4) << "RequestGetMonomerBarrier " << varname;
+
+    rpc_server_->WaitVarCond(varname);
+    MonomerHandle h = rpc_server_->GetMonomer(varname);
+
+    framework::Scope* scope = nullptr;
+    framework::Variable* invar = nullptr;
+    framework::Variable* outvar = nullptr;
+
+    request_handler_->Handle(varname, scope, invar, &outvar,
+                             request_.trainer_id());
+
+    Finish(reply_, &responder_);
+  }
+
+ protected:
+  sendrecv::VariableMessage request_;
+  sendrecv::VoidMessage reply_;
+  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
+  RPCServer* rpc_server_{nullptr};
+};
+
 class RequestPrefetch final : public RequestBase {
  public:
   explicit RequestPrefetch(GrpcService::AsyncService* service,
@@ -249,7 +341,7 @@ class RequestCheckpointNotify final : public RequestBase {
 };
 
 void AsyncGRPCServer::WaitServerReady() {
-  VLOG(4) << "AsyncGRPCServer is wait server ready";
+  VLOG(4) << "AsyncGRPCServer is waiting server ready";
   std::unique_lock<std::mutex> lock(this->mutex_ready_);
   condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
   VLOG(4) << "AsyncGRPCServer WaitSeverReady";
@@ -368,6 +460,12 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
     b = new RequestSend(&service_, cq.get(), handler, req_id);
   } else if (rpc_name == kRequestGet) {
     b = new RequestGet(&service_, cq.get(), handler, req_id);
+  } else if (rpc_name == kRequestGetMonomerVariable) {
+    b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
+                                      this);
+  } else if (rpc_name == kRequestGetMonomerBarrier) {
+    b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id,
+                                     this);
   } else if (rpc_name == kRequestPrefetch) {
     b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
   } else if (rpc_name == kRequestCheckpoint) {
@@ -378,7 +476,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
 
   reqs[req_id] = b;
 
-  VLOG(4) << "Create RequestSend status:" << b->Status();
+  VLOG(4) << "TryToRegisterNewOne status:" << b->Status();
 }
 
 void AsyncGRPCServer::HandleRequest(
diff --git a/paddle/fluid/operators/distributed/grpc_service.h b/paddle/fluid/operators/distributed/grpc_service.h
index 9ae9a31a00..537429b5fe 100644
--- a/paddle/fluid/operators/distributed/grpc_service.h
+++ b/paddle/fluid/operators/distributed/grpc_service.h
@@ -81,10 +81,12 @@ enum class GrpcMethod {
   kGetVariable,
   kPrefetchVariable,
   kCheckpointNotify,
+  kGetMonomerVariable,
+  kGetMonomerBarrier,
 };
 
 static const int kGrpcNumMethods =
-    static_cast<int>(GrpcMethod::kCheckpointNotify) + 1;
+    static_cast<int>(GrpcMethod::kGetMonomerBarrier) + 1;
 
 inline const char* GrpcMethodName(GrpcMethod id) {
   switch (id) {
@@ -92,6 +94,10 @@ inline const char* GrpcMethodName(GrpcMethod id) {
       return "/sendrecv.SendRecvService/SendVariable";
     case GrpcMethod::kGetVariable:
       return "/sendrecv.SendRecvService/GetVariable";
+    case GrpcMethod::kGetMonomerVariable:
+      return "/sendrecv.SendRecvService/GetMonomerVariable";
+    case GrpcMethod::kGetMonomerBarrier:
+      return "/sendrecv.SendRecvService/GetMonomerBarrier";
     case GrpcMethod::kPrefetchVariable:
       return "/sendrecv.SendRecvService/PrefetchVariable";
     case GrpcMethod::kCheckpointNotify:
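The `kGrpcNumMethods` update above works because the two new enumerators are appended after `kCheckpointNotify`, which keeps the existing method ids stable, and because the count is always derived from the last enumerator. A minimal model of that invariant; the enumerator names below are illustrative:

```cpp
#include <cassert>

// Appending enumerators keeps existing ids stable, and the method count is
// always "last enumerator + 1".
enum class Method {
  kSend,
  kGet,
  kPrefetch,
  kCheckpoint,
  kMonomerVariable,  // appended: ids of the methods above are unchanged
  kMonomerBarrier,
};
constexpr int kNumMethods = static_cast<int>(Method::kMonomerBarrier) + 1;

int main() {
  static_assert(kNumMethods == 6, "count must track the last enumerator");
  assert(static_cast<int>(Method::kSend) == 0);  // old ids untouched
  return 0;
}
```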
diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h
index 5272afd428..62b24f150b 100644
--- a/paddle/fluid/operators/distributed/request_handler.h
+++ b/paddle/fluid/operators/distributed/request_handler.h
@@ -37,6 +37,8 @@ namespace distributed {
 
 constexpr char kRequestSend[] = "RequestSend";
 constexpr char kRequestGet[] = "RequestGet";
+constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
+constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
 constexpr char kRequestPrefetch[] = "RequestPrefetch";
 constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
 constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h
index 4cd3abb5a6..b668d86978 100644
--- a/paddle/fluid/operators/distributed/rpc_client.h
+++ b/paddle/fluid/operators/distributed/rpc_client.h
@@ -45,6 +45,11 @@ class RPCClient {
                                    const std::string& var_name,
                                    int64_t time_out = FLAGS_rpc_deadline) = 0;
 
+  virtual VarHandlePtr AsyncGetMonomerVariable(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;
+
   virtual VarHandlePtr AsyncPrefetchVar(
       const std::string& ep, const platform::DeviceContext& ctx,
       const framework::Scope& scope, const std::string& in_var_name,
@@ -57,6 +62,10 @@ class RPCClient {
   virtual VarHandlePtr AsyncSendFetchBarrier(
       const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
 
+  virtual VarHandlePtr AsyncGetMonomerBarrier(
+      const std::string& ep, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;
+
   virtual VarHandlePtr AsyncCheckpointNotify(
       const std::string& ep, const std::string& dir,
       int64_t time_out = FLAGS_rpc_deadline) = 0;
@@ -87,8 +96,9 @@ class RPCClient {
     }
   }
 
- protected:
   virtual void InitImpl() {}
+
+ protected:
   // each trainer have exact one trainer id, it should be static
   static int trainer_id_;
 
diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index 3e30ed4ac8..122619d41b 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -132,6 +132,96 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
       lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
 }
 
+void RPCServer::RegisterVar(const std::string& var_name,
+                            const std::string& rpc_name,
+                            framework::Scope* scope,
+                            platform::DeviceContext* dev_ctx) {
+  MonomerHandle h;
+  h.var_name_ = var_name;
+  h.rpc_name_ = rpc_name;
+  h.scope_ = scope;
+  h.dev_ctx_ = dev_ctx;
+
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (var_map_.find(var_name) != var_map_.end()) {
+      PADDLE_ENFORCE(false, "%s already in var_map", var_name);
+    }
+    var_map_[var_name] = h;
+  }
+
+  rpc_cond_.notify_all();
+  VLOG(4) << "RegisterVar context:" << h.String();
+}
+
+void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
+  int b = 0;
+  MonomerHandle h;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    b = ++var_map_[var_name].barrier_;
+    h = var_map_[var_name];
+  }
+
+  if (b >= client_num_) {
+    barrier_cond_.notify_all();
+  }
+
+  VLOG(4) << "IncreaseVarBarrier context:" << h.String();
+}
+
+void RPCServer::WaitVarBarrier(const std::string& var_name) {
+  VLOG(4) << "WaitBarrier var_name:" << var_name;
+
+  std::unique_lock<std::mutex> lock(mutex_);
+  barrier_cond_.wait(lock, [&]() {
+    return ((var_map_[var_name].barrier_ >= client_num_ && client_num_ != 0) ||
+            exit_flag_.load());
+  });
+
+  VLOG(4) << "WaitBarrier context: " << var_map_[var_name].String();
+}
+
+void RPCServer::SetVarCond(const std::string& var_name) {
+  VLOG(4) << "SetVarCond var_name:" << var_name;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (var_map_.find(var_name) != var_map_.end()) {
+      rpc_cond_.notify_all();
+    }
+  }
+}
+
+void RPCServer::WaitVarCond(const std::string& var_name) {
+  VLOG(4) << "WaitVarCond var_name:" << var_name;
+
+  std::unique_lock<std::mutex> lock(mutex_);
+  rpc_cond_.wait(lock, [=] {
+    return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load());
+  });
+
+  VLOG(4) << "WaitVarCond var_name:" << var_name << " end";
+}
+
+MonomerHandle RPCServer::GetMonomer(const std::string& var_name) {
+  MonomerHandle h;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    h = var_map_[var_name];
+  }
+
+  return h;
+}
+
+void RPCServer::ClearRegisteredVars() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  var_map_.clear();
+}
+
+void RPCServer::ClearVar(const std::string& var_name) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  var_map_.erase(var_name);
+}
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle
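
rpc_server.cc implements the monomer protocol with one mutex and two condition variables: RegisterVar publishes a variable under mutex_ and wakes WaitVarCond() waiters via rpc_cond_, while WaitVarBarrier() sleeps on barrier_cond_ until IncreaseVarBarrier() has run client_num_ times. A runnable toy of the same synchronization shape, assuming a fixed client_num of 2 and omitting the exit_flag_ path:

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <unordered_map>

    std::mutex mu;
    std::condition_variable rpc_cond, barrier_cond;
    std::unordered_map<std::string, int> barriers;  // var -> barrier count
    const int client_num = 2;

    void RegisterVar(const std::string& var) {
      { std::lock_guard<std::mutex> l(mu); barriers[var] = 0; }
      rpc_cond.notify_all();  // wake WaitVarCond() callers
    }

    void WaitVarCond(const std::string& var) {
      std::unique_lock<std::mutex> l(mu);
      rpc_cond.wait(l, [&] { return barriers.count(var) > 0; });
    }

    void IncreaseVarBarrier(const std::string& var) {
      int b;
      { std::lock_guard<std::mutex> l(mu); b = ++barriers[var]; }
      if (b >= client_num) barrier_cond.notify_all();
    }

    void WaitVarBarrier(const std::string& var) {
      std::unique_lock<std::mutex> l(mu);
      barrier_cond.wait(l, [&] { return barriers[var] >= client_num; });
    }

    int main() {
      std::thread t([] { WaitVarCond("v"); IncreaseVarBarrier("v"); });
      RegisterVar("v");
      IncreaseVarBarrier("v");  // this thread acts as the second client
      WaitVarBarrier("v");      // releases once both clients have arrived
      t.join();
      std::cout << "barrier complete\n";
      return 0;
    }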
diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h
index c78c5007a7..45d1d3479c 100644
--- a/paddle/fluid/operators/distributed/rpc_server.h
+++ b/paddle/fluid/operators/distributed/rpc_server.h
@@ -21,12 +21,30 @@
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/operators/distributed/request_handler.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
 namespace distributed {
 
+struct MonomerHandle {
+  std::string var_name_;
+  std::string rpc_name_;
+  framework::Scope* scope_{nullptr};
+  platform::DeviceContext* dev_ctx_{nullptr};
+  int64_t barrier_{0};
+
+  std::string String() {
+    std::stringstream ss;
+    ss << "var_name:" << var_name_ << ", rpc_name:" << rpc_name_
+       << ", scope:" << scope_ << ", dev_ctx:" << dev_ctx_
+       << ", barrier_:" << barrier_;
+    return ss.str();
+  }
+};
+
 class RPCServer {
  public:
   explicit RPCServer(const std::string& address, int client_num)
@@ -67,6 +85,16 @@ class RPCServer {
   void WaitCond(const std::string& rpc_name);
   void IncreaseBatchBarrier(const std::string rpc_name);
 
+  void RegisterVar(const std::string& var_name, const std::string& rpc_name,
+                   framework::Scope* scope, platform::DeviceContext* dev_ctx);
+  void IncreaseVarBarrier(const std::string& var_name);
+  void WaitVarBarrier(const std::string& var_name);
+  void SetVarCond(const std::string& var_name);
+  void WaitVarCond(const std::string& var_name);
+  void ClearRegisteredVars();
+  void ClearVar(const std::string& var_name);
+  MonomerHandle GetMonomer(const std::string& var_name);
+
   void Complete();
 
   void ResetBarrierCounter();
@@ -95,6 +123,9 @@ class RPCServer {
   std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
   std::unordered_map<std::string, int> rpc_thread_num_;
   friend class RequestHandler;
+
+  // TODO(gongwb): use more condition variables to notify or wait;
+  std::unordered_map<std::string, MonomerHandle> var_map_;
 };
 
 };  // namespace distributed
diff --git a/paddle/fluid/operators/distributed/send_recv.proto.in b/paddle/fluid/operators/distributed/send_recv.proto.in
index 7b7d069f17..2637619f30 100644
--- a/paddle/fluid/operators/distributed/send_recv.proto.in
+++ b/paddle/fluid/operators/distributed/send_recv.proto.in
@@ -28,6 +28,9 @@ service SendRecvService {
   rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
 
   rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
+
+  rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
+  rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
 }
 
 // VariableMessage is serialized paddle variable message.
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index dc25bc5710..a8b8a67a11 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -60,15 +60,37 @@ template <typename DeviceContext, typename T>
 class ElementwiseMulKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto x_var = ctx.InputVar("X");
+    PADDLE_ENFORCE(x_var != nullptr,
+                   "Cannot get input Variable X, variable name = %s",
+                   ctx.op().Input("X"));
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
+
+    framework::Tensor x, *z;
+    if (x_var->IsType<framework::SelectedRows>()) {
+      PADDLE_ENFORCE(y->dims().size() == 1 && y->dims()[0] == 1,
+                     "For elementwise_op, if X is Sparse, Y must be scalar.");
+      auto& x_sele = x_var->Get<framework::SelectedRows>();
+      auto out_sele = ctx.Output<framework::SelectedRows>("Out");
+      x = x_sele.value();
+      out_sele->set_rows(x_sele.rows());
+      out_sele->set_height(x_sele.height());
+      out_sele->mutable_value()->Resize(x_sele.value().dims());
+      out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
+      z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
+    } else if (x_var->IsType<framework::LoDTensor>()) {
+      x = x_var->Get<framework::LoDTensor>();
+      z = ctx.Output<framework::LoDTensor>("Out");
+    } else {
+      PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
+                   x_var->Type().name());
+    }
 
     z->mutable_data<T>(ctx.GetPlace());
-    if (x->numel() == y->numel()) {
-      elementwise_mul<DeviceContext, T>(ctx, x, y, z);
+    if (x.numel() == y->numel()) {
+      elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
     } else {
-      default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
+      default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
     }
   }
 };
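
The kernel above now accepts a SelectedRows X when Y is a single-element tensor: rows and height pass through to the output unchanged and only the materialized values are scaled. A toy sketch of that semantics with a minimal SelectedRows stand-in (not the real framework types):

    #include <cassert>
    #include <vector>

    // Minimal stand-in for SelectedRows: a dense value block plus the row
    // indices it represents within a taller logical tensor.
    struct ToySelectedRows {
      std::vector<int> rows;     // which logical rows are materialized
      std::vector<float> value;  // row-major values for those rows
      int height;                // logical row count
    };

    // elementwise_mul with sparse X and scalar Y: copy rows/height through,
    // scale only the materialized values (mirrors the kernel's new branch).
    ToySelectedRows MulScalar(const ToySelectedRows& x, float y) {
      ToySelectedRows out{x.rows, {}, x.height};
      out.value.reserve(x.value.size());
      for (float v : x.value) out.value.push_back(v * y);
      return out;
    }

    int main() {
      ToySelectedRows x{{0, 3}, {1.f, 2.f, 3.f, 4.f}, 8};  // 2 rows, width 2
      ToySelectedRows z = MulScalar(x, 0.5f);
      assert(z.rows == x.rows && z.height == 8);
      assert(z.value[3] == 2.f);
      return 0;
    }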
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index 85a7817be9..87bf7c6b15 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -40,21 +40,28 @@ class ElementwiseOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of elementwise op should not be null.");
 
-    PADDLE_ENFORCE(
-        ctx->GetInputsVarType("X").front() ==
-            framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
     PADDLE_ENFORCE(
         ctx->GetInputsVarType("Y").front() ==
             framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("Y").front(), ctx->GetInputsVarType("Y").front());
-
-    auto x_dim = ctx->GetInputDim("X");
-    auto y_dim = ctx->GetInputDim("Y");
-    PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
-                      "Rank of first input must >= rank of second input.");
+        "The input var's type should be LoDTensor, but the received is %s [%s]",
+        ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
+
+    if (ctx->GetInputsVarType("X").front() ==
+        framework::proto::VarType::LOD_TENSOR) {
+      auto x_dim = ctx->GetInputDim("X");
+      auto y_dim = ctx->GetInputDim("Y");
+      PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
+                        "Rank of first input must >= rank of second input.");
+    } else if (ctx->GetInputsVarType("X").front() ==
+               framework::proto::VarType::SELECTED_ROWS) {
+      PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
+                         (ctx->GetInputDim("Y")[0] == 1),
+                     "For elementwise_op, if X is Sparse, "
+                     "Y must be scalar.");
+    } else {
+      PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
+                   ctx->GetInputsVarType("X").front());
+    }
 
     ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
index 6d463538d2..1eb6523a2d 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
@@ -217,13 +217,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
   auto& act_gate_str = ctx.Attr<std::string>("gate_activation");               \
   auto& act_cell_str = ctx.Attr<std::string>("cell_activation");               \
   auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");          \
-  if (platform::jit::MayIUse(platform::jit::avx)) {                            \
-    math::VecActivations<T, platform::jit::avx> act_functor;                   \
+  if (platform::MayIUse(platform::avx)) {                                      \
+    math::VecActivations<T, platform::avx> act_functor;                        \
     act_gate = act_functor(act_gate_str);                                      \
     act_cell = act_functor(act_cell_str);                                      \
     act_cand = act_functor(act_cand_str);                                      \
   } else {                                                                     \
-    math::VecActivations<T, platform::jit::isa_any> act_functor;               \
+    math::VecActivations<T, platform::isa_any> act_functor;                    \
     act_gate = act_functor(act_gate_str);                                      \
     act_cell = act_functor(act_cell_str);                                      \
     act_cand = act_functor(act_cand_str);                                      \
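
Both fused kernels keep the same runtime-dispatch shape after this rename; only the namespace moved from platform::jit to platform. A portable toy of the MayIUse-style dispatch, where HasAVX() is a hypothetical stand-in for platform::MayIUse(platform::avx):

    #include <functional>
    #include <iostream>

    // Probe the CPU once, then pick the widest implementation available.
    bool HasAVX() {
    #ifdef __AVX__
      return true;   // compile-time stand-in; the real check is a CPUID probe
    #else
      return false;
    #endif
    }

    void scale_generic(int n, float a, float* x) {
      for (int i = 0; i < n; ++i) x[i] *= a;
    }

    void scale_avx(int n, float a, float* x) {
      // A real AVX body would process 8 floats per iteration; the toy just
      // forwards so the example stays portable.
      scale_generic(n, a, x);
    }

    int main() {
      std::function<void(int, float, float*)> scale =
          HasAVX() ? scale_avx : scale_generic;
      float v[4] = {1, 2, 3, 4};
      scale(4, 2.f, v);
      std::cout << v[3] << "\n";  // prints 8
      return 0;
    }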
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index 288b56fc24..17ed9771d0 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -151,11 +151,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
 
     std::function<void(const int, const T*, T*)> fc_act;
     auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
+    if (platform::MayIUse(platform::avx)) {
+      math::VecActivations<T, platform::avx> act_functor;
       fc_act = act_functor(fc_act_str);
     } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      math::VecActivations<T, platform::isa_any> act_functor;
       fc_act = act_functor(fc_act_str);
     }
 
diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
new file mode 100644
index 0000000000..a4ae19d9c1
--- /dev/null
+++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
@@ -0,0 +1,117 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
+
+namespace paddle {
+namespace operators {
+
+class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "GetTensorFromSelectedRowsOp must has input X.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "GetTensorFromSelectedRowsOp must has output Out.");
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("X").front() ==
+            framework::proto::VarType::SELECTED_ROWS,
+        "The input X's type should be SelectedRows, but the received is %s",
+        ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
+    PADDLE_ENFORCE(
+        ctx->GetOutputsVarType("Out").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The output Out's type should be LoDTensor, but the received is %s",
+        ctx->Outputs("Out").front(), ctx->GetOutputsVarType("Out").front());
+
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context());
+  }
+};
+
+class GetTensorFromSelectedRowsKernel {
+ public:
+  void operator()(const framework::ExecutionContext &ctx) const {
+    auto *x = ctx.Input<framework::SelectedRows>("X");
+    auto *out = ctx.Output<framework::LoDTensor>("Out");
+
+    out->Resize(x->value().dims());
+    out->mutable_data(ctx.GetPlace(), x->value().type());
+    framework::TensorCopy(x->value(), ctx.GetPlace(), ctx.device_context(),
+                          out);
+  }
+};
+
+class GetTensorFromSelectedRowsOpProtoMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input type is SelectedRows.");
+    AddOutput("Out", "The output type is LoDTensor.");
+    AddComment(
+        R"DOC(
+GetTensorFromSelectedRows Operator
+
+GetTensorFromSelectedRows is used to extract the dense value tensor from a
+SelectedRows variable.
+
+)DOC");
+  }
+};
+
+class GetTensorFromSelectedRowsOpVarTypeInference
+    : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const final {
+    auto out_var_name = op_desc.Output("Out").front();
+    auto in_var_name = op_desc.Input("X").front();
+
+    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
+    out_var.SetType(framework::proto::VarType::LOD_TENSOR);
+    out_var.SetDataType(in_var.GetDataType());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(get_tensor_from_selected_rows,
+                  ops::GetTensorFromSelectedRowsOp,
+                  ops::GetTensorFromSelectedRowsOpProtoMaker,
+                  ops::GetTensorFromSelectedRowsOpVarTypeInference);
+
+REGISTER_OP_CPU_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float,
+                               ops::GetTensorFromSelectedRowsKernel, double,
+                               ops::GetTensorFromSelectedRowsKernel, int,
+                               ops::GetTensorFromSelectedRowsKernel, int64_t,
+                               ops::GetTensorFromSelectedRowsKernel);
+
+#ifdef PADDLE_WITH_CUDA
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float,
+                                ops::GetTensorFromSelectedRowsKernel, double,
+                                ops::GetTensorFromSelectedRowsKernel, int,
+                                ops::GetTensorFromSelectedRowsKernel, int64_t,
+                                ops::GetTensorFromSelectedRowsKernel);
+#endif
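
The new op is intentionally thin: the output LoDTensor is just a copy of the SelectedRows' dense value block, with the row indices dropped. A toy sketch of that behavior under assumed simplified types, not the framework's:

    #include <cassert>
    #include <vector>

    struct ToySelectedRows {
      std::vector<int> rows;     // indices of the materialized rows
      std::vector<float> value;  // the dense value block
    };

    // Mirrors the kernel: TensorCopy(x->value(), ..., out). The row indices
    // are deliberately dropped; in this sketch the caller must know that the
    // output holds only the stored rows.
    std::vector<float> GetTensorFromSelectedRows(const ToySelectedRows& x) {
      return x.value;
    }

    int main() {
      ToySelectedRows x{{1, 4}, {0.5f, 1.5f}};
      std::vector<float> out = GetTensorFromSelectedRows(x);
      assert(out.size() == 2 && out[1] == 1.5f);
      return 0;
    }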
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index 972dcf5494..0dbcc442df 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -150,14 +150,14 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
                    "Output(W@Grad should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Output(X@Grad should not be null.");
-    if (!ctx->Attrs().Get<bool>("is_sparse")) {
-      if (ctx->HasOutput(framework::GradVarName("Bias"))) {
-        ctx->SetOutputDim(framework::GradVarName("Bias"),
-                          ctx->GetInputDim("Bias"));
-      }
-      ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+
+    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Bias"),
+                        ctx->GetInputDim("Bias"));
     }
+    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 
  protected:
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index 07ff8f947e..b73a32af89 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -185,7 +185,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
           ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
       w_grad->set_rows(real_rows);
       // Build a map of id -> row_index to speed up finding the index of one id
-      w_grad->SyncIndex();
       w_grad->set_height(w.dims()[0]);
       auto* w_grad_value = w_grad->mutable_value();
       framework::DDim temp_dim(w.dims());
diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc
index 0522a94195..9d1423915a 100644
--- a/paddle/fluid/operators/load_combine_op.cc
+++ b/paddle/fluid/operators/load_combine_op.cc
@@ -32,16 +32,26 @@ class LoadCombineOp : public framework::OperatorBase {
                const platform::Place &place) const override {
     auto filename = Attr<std::string>("file_path");
     auto load_as_fp16 = Attr<bool>("load_as_fp16");
-
-    std::ifstream fin(filename);
-    PADDLE_ENFORCE(static_cast<bool>(fin),
-                   "Cannot open file %s for load_combine op", filename);
-
+    auto model_from_memory = Attr<bool>("model_from_memory");
     auto out_var_names = Outputs("Out");
     PADDLE_ENFORCE_GT(
         static_cast<int>(out_var_names.size()), 0,
         "The number of output variables should be greater than 0.");
-
+    if (!model_from_memory) {
+      std::ifstream fin(filename);
+      PADDLE_ENFORCE(static_cast<bool>(fin),
+                     "Cannot open file %s for load_combine op", filename);
+      LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
+    } else {
+      PADDLE_ENFORCE(!filename.empty(),
+                     "Cannot load model from memory: the buffer is empty");
+      std::stringstream fin(filename);
+      LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
+    }
+  }
+  void LoadParamsFromBuffer(
+      const framework::Scope &scope, const platform::Place &place,
+      std::istream *buffer, bool load_as_fp16,
+      const std::vector<std::string> &out_var_names) const {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
 
@@ -54,11 +64,10 @@ class LoadCombineOp : public framework::OperatorBase {
       auto *tensor = out_var->GetMutable<framework::LoDTensor>();
 
       // Error checking
-      PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot read more from file %s",
-                     filename);
+      PADDLE_ENFORCE(static_cast<bool>(*buffer),
+                     "Cannot read more from the given buffer");
 
       // Get data from fin to tensor
-      DeserializeFromStream(fin, tensor, dev_ctx);
+      DeserializeFromStream(*buffer, tensor, dev_ctx);
 
       auto in_dtype = framework::ToDataType(tensor->type());
       auto out_dtype =
@@ -103,11 +112,17 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
                          "LoDTensors will be loaded from \"file_path\".")
         .AddCustomChecker(
             [](const std::string &path) { return !path.empty(); });
+    AddAttr<bool>("model_from_memory",
+                  "(boolean, default false)"
+                  "If true, file_path is in memory, and LoDTensors will be "
+                  "loaded directly from memory")
+        .SetDefault(false);
     AddComment(R"DOC(
 LoadCombine Operator.
 
-LoadCombine operator loads LoDTensor variables from a file. The file should 
-contain one or more LoDTensors serialized using the SaveCombine operator. The 
+LoadCombine operator loads LoDTensor variables from a file, or directly from 
+an in-memory buffer when model_from_memory is set. The input should contain 
+one or more LoDTensors serialized using the SaveCombine operator. The
 LoadCombine operator applies a deserialization strategy to appropriately load 
 the LodTensors, and this strategy complements the serialization strategy used 
 in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
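
The refactor above funnels both paths through one std::istream*-based loader: an std::ifstream when reading from disk, an std::stringstream when file_path already holds the bytes (model_from_memory). A self-contained toy of the same split, with a token reader standing in for DeserializeFromStream:

    #include <fstream>
    #include <iostream>
    #include <sstream>
    #include <string>

    // One loader over std::istream serves both sources.
    void LoadParamsFromBuffer(std::istream* buffer) {
      std::string token;
      while (*buffer >> token) {
        std::cout << "loaded: " << token << "\n";  // stand-in for deserialization
      }
    }

    void Load(const std::string& file_path, bool model_from_memory) {
      if (!model_from_memory) {
        std::ifstream fin(file_path);      // file_path is a real path
        LoadParamsFromBuffer(&fin);
      } else {
        std::stringstream fin(file_path);  // file_path holds the bytes themselves
        LoadParamsFromBuffer(&fin);
      }
    }

    int main() {
      Load("w0 w1 b0", /*model_from_memory=*/true);
      return 0;
    }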
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 63363086ad..b3d2ea38eb 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -59,6 +59,7 @@ math_library(matrix_bit_code)
 
 math_library(unpooling)
 math_library(vol2col)
+math_library(prelu)
 
 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
 cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 7d81aee596..e1e4d168db 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -77,7 +77,7 @@ inline void vec_scal<double>(const int n, const double a, double* x) {
 #endif
 
 // MKL scal only support inplace, choose this if src and dst are not equal
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_scal(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a * x[i];
@@ -85,12 +85,12 @@ inline void vec_scal(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
-                                                const float* x, float* y) {
+inline void vec_scal<float, platform::avx>(const int n, const float a,
+                                           const float* x, float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_scal<float, platform::jit::isa_any>(n, a, x, y);
+    vec_scal<float, platform::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -114,24 +114,24 @@ inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
     y[i] = a * x[i];
   }
 #else
-  vec_scal<float, platform::jit::isa_any>(n, a, x, y);
+  vec_scal<float, platform::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
-                                                 const float* x, float* y) {
-  vec_scal<float, platform::jit::avx>(n, a, x, y);
+inline void vec_scal<float, platform::avx2>(const int n, const float a,
+                                            const float* x, float* y) {
+  vec_scal<float, platform::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_scal<float, platform::jit::avx512f>(const int n, const float a,
-                                                    const float* x, float* y) {
+inline void vec_scal<float, platform::avx512f>(const int n, const float a,
+                                               const float* x, float* y) {
   // TODO(TJ): enable me
-  vec_scal<float, platform::jit::avx2>(n, a, x, y);
+  vec_scal<float, platform::avx2>(n, a, x, y);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a - x[i];
@@ -139,12 +139,12 @@ inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
-                                                    const float* x, float* y) {
+inline void vec_bias_sub<float, platform::avx>(const int n, const float a,
+                                               const float* x, float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
+    vec_bias_sub<float, platform::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -168,27 +168,25 @@ inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
     y[i] = a - x[i];
   }
 #else
-  vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
+  vec_bias_sub<float, platform::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
-                                                     const float* x, float* y) {
-  vec_bias_sub<float, platform::jit::avx>(n, a, x, y);
+inline void vec_bias_sub<float, platform::avx2>(const int n, const float a,
+                                                const float* x, float* y) {
+  vec_bias_sub<float, platform::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_bias_sub<float, platform::jit::avx512f>(const int n,
-                                                        const float a,
-                                                        const float* x,
-                                                        float* y) {
+inline void vec_bias_sub<float, platform::avx512f>(const int n, const float a,
+                                                   const float* x, float* y) {
   // TODO(TJ): enable me
-  vec_bias_sub<float, platform::jit::avx2>(n, a, x, y);
+  vec_bias_sub<float, platform::avx2>(n, a, x, y);
 }
 
 // out = x*y + (1-x)*z
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
   for (int i = 0; i < n; ++i) {
     out[i] = x[i] * y[i] + (static_cast<T>(1) - x[i]) * z[i];
@@ -196,13 +194,13 @@ inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
 }
 
 template <>
-inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
-                                                 const float* y, const float* z,
-                                                 float* out) {
+inline void vec_cross<float, platform::avx>(const int n, const float* x,
+                                            const float* y, const float* z,
+                                            float* out) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
+    vec_cross<float, platform::isa_any>(n, x, y, z, out);
     return;
   }
   const int rest = n % block;
@@ -228,25 +226,26 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
     out[i] = x[i] * y[i] + (1.f - x[i]) * z[i];
   }
 #else
-  vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
+  vec_cross<float, platform::isa_any>(n, x, y, z, out);
 #endif
 }
 
 template <>
-inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
-                                                  const float* y,
-                                                  const float* z, float* out) {
-  vec_cross<float, platform::jit::avx>(n, x, y, z, out);
+inline void vec_cross<float, platform::avx2>(const int n, const float* x,
+                                             const float* y, const float* z,
+                                             float* out) {
+  vec_cross<float, platform::avx>(n, x, y, z, out);
 }
 
 template <>
-inline void vec_cross<float, platform::jit::avx512f>(
-    const int n, const float* x, const float* y, const float* z, float* out) {
+inline void vec_cross<float, platform::avx512f>(const int n, const float* x,
+                                                const float* y, const float* z,
+                                                float* out) {
   // TODO(TJ): enable me
-  vec_cross<float, platform::jit::avx>(n, x, y, z, out);
+  vec_cross<float, platform::avx>(n, x, y, z, out);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] + a;
@@ -254,12 +253,12 @@ inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
-                                                    const float* x, float* y) {
+inline void vec_add_bias<float, platform::avx>(const int n, const float a,
+                                               const float* x, float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
+    vec_add_bias<float, platform::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -283,32 +282,30 @@ inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
     y[i] = x[i] + a;
   }
 #else
-  vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
+  vec_add_bias<float, platform::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
-                                                     const float* x, float* y) {
-  vec_add_bias<float, platform::jit::avx>(n, a, x, y);
+inline void vec_add_bias<float, platform::avx2>(const int n, const float a,
+                                                const float* x, float* y) {
+  vec_add_bias<float, platform::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_add_bias<float, platform::jit::avx512f>(const int n,
-                                                        const float a,
-                                                        const float* x,
-                                                        float* y) {
+inline void vec_add_bias<float, platform::avx512f>(const int n, const float a,
+                                                   const float* x, float* y) {
   // TODO(TJ): enable me
-  vec_add_bias<float, platform::jit::avx2>(n, a, x, y);
+  vec_add_bias<float, platform::avx2>(n, a, x, y);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_identity(const int n, const T* x, T* y) {
   // do nothing
   return;
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_sigmoid(const int n, const T* x, T* y) {
   const T min = SIGMOID_THRESHOLD_MIN;
   const T max = SIGMOID_THRESHOLD_MAX;
@@ -323,12 +320,12 @@ inline void vec_sigmoid(const int n, const T* x, T* y) {
 }
 
 template <>
-inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
-                                                   float* y) {
+inline void vec_sigmoid<float, platform::avx>(const int n, const float* x,
+                                              float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
+    vec_sigmoid<float, platform::isa_any>(n, x, y);
     return;
   }
   const int rest = n % block;
@@ -377,25 +374,24 @@ inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
     y[i] = 1.f / (1.f + y[i]);
   }
 #else
-  vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
+  vec_sigmoid<float, platform::isa_any>(n, x, y);
 #endif
 }
 
 template <>
-inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
-                                                    float* y) {
-  vec_sigmoid<float, platform::jit::avx>(n, x, y);
+inline void vec_sigmoid<float, platform::avx2>(const int n, const float* x,
+                                               float* y) {
+  vec_sigmoid<float, platform::avx>(n, x, y);
 }
 
 template <>
-inline void vec_sigmoid<float, platform::jit::avx512f>(const int n,
-                                                       const float* x,
-                                                       float* y) {
+inline void vec_sigmoid<float, platform::avx512f>(const int n, const float* x,
+                                                  float* y) {
   // TODO(TJ): enable me
-  vec_sigmoid<float, platform::jit::avx2>(n, x, y);
+  vec_sigmoid<float, platform::avx2>(n, x, y);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_tanh(const int n, const T* x, T* y) {
   vec_scal<T, isa>(n, static_cast<T>(2), x, y);
   vec_sigmoid<T, isa>(n, y, y);
@@ -404,7 +400,7 @@ inline void vec_tanh(const int n, const T* x, T* y) {
 }
 
 // TODO(TJ): make relu clip
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_relu(const int n, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] > 0 ? x[i] : 0;
@@ -412,12 +408,12 @@ inline void vec_relu(const int n, const T* x, T* y) {
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
-                                                float* y) {
+inline void vec_relu<float, platform::avx>(const int n, const float* x,
+                                           float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block * 4) {
-    vec_relu<float, platform::jit::isa_any>(n, x, y);
+    vec_relu<float, platform::isa_any>(n, x, y);
     return;
   }
 
@@ -441,26 +437,26 @@ inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
 #undef MOVE_ONE_STEP
 
 #else
-  vec_relu<float, platform::jit::isa_any>(n, x, y);
+  vec_relu<float, platform::isa_any>(n, x, y);
 #endif
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
-                                                 float* y) {
-  vec_relu<float, platform::jit::avx>(n, x, y);
+inline void vec_relu<float, platform::avx2>(const int n, const float* x,
+                                            float* y) {
+  vec_relu<float, platform::avx>(n, x, y);
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx512f>(const int n, const float* x,
-                                                    float* y) {
+inline void vec_relu<float, platform::avx512f>(const int n, const float* x,
+                                               float* y) {
   // TODO(TJ): enable me
-  vec_relu<float, platform::jit::avx2>(n, x, y);
+  vec_relu<float, platform::avx2>(n, x, y);
 }
 
 // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 class VecActivations {
  public:
   std::function<void(const int, const T*, T*)> operator()(
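
Throughout cpu_vec.h the pattern is a scalar default template plus full specializations keyed on the cpu_isa_t non-type parameter, with wider ISAs delegating to narrower ones until a dedicated body lands (the TODO(TJ) notes above). A compilable toy of that structure:

    #include <cassert>

    enum cpu_isa_t { isa_any, avx, avx2 };

    // Generic portable fallback for any ISA.
    template <typename T, cpu_isa_t isa = isa_any>
    inline void vec_scal(int n, T a, const T* x, T* y) {
      for (int i = 0; i < n; ++i) y[i] = a * x[i];
    }

    template <>
    inline void vec_scal<float, avx>(int n, float a, const float* x, float* y) {
      // A real body would use 256-bit intrinsics for full blocks and hand the
      // remainder to the isa_any version; the toy delegates everything.
      vec_scal<float, isa_any>(n, a, x, y);
    }

    template <>
    inline void vec_scal<float, avx2>(int n, float a, const float* x, float* y) {
      vec_scal<float, avx>(n, a, x, y);  // not yet specialized: reuse avx
    }

    int main() {
      float x[3] = {1, 2, 3}, y[3];
      vec_scal<float, avx2>(3, 2.f, x, y);
      assert(y[2] == 6.f);
      return 0;
    }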
diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc
index c37fa291a2..28eb9cadc9 100644
--- a/paddle/fluid/operators/math/cpu_vec_test.cc
+++ b/paddle/fluid/operators/math/cpu_vec_test.cc
@@ -104,38 +104,42 @@ void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt,
 }
 
 TEST(CpuVecTest, sigmoid) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
-    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
-    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
-    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512f>,
+    TestAndBench<float>(sz, vec_sigmoid<float, platform::avx>,
+                        ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, platform::avx2>,
+                        ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, platform::avx512f>,
                         ref_sigmoid<float>);
   }
   TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
 }
 
 TEST(CpuVecTest, tanh) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
-    TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
-    TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
-    TestAndBench<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, platform::avx512f>,
+                        ref_tanh<float>);
   }
   TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
 }
 
 TEST(CpuVecTest, relu) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
-    TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
-    TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
-    TestAndBench<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, platform::avx512f>,
+                        ref_relu<float>);
   }
   TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
 }
@@ -162,38 +166,40 @@ void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
 }
 
 TEST(CpuVecTest, inplace_sigmoid) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
-    TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
-    TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
-    TestInplace<float>(sz, vec_sigmoid<float, jit::avx512f>,
+    TestInplace<float>(sz, vec_sigmoid<float, platform::avx>,
+                       ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, platform::avx2>,
+                       ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, platform::avx512f>,
                        ref_sigmoid<float>);
   }
   TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
 }
 
 TEST(CpuVecTest, inplace_tanh) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, platform::avx512f>, ref_tanh<float>);
   }
   TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
 }
 
 TEST(CpuVecTest, inplace_relu) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, platform::avx512f>, ref_relu<float>);
   }
   TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
 }
diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index 52cbdf685d..78d0c3e880 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -22,7 +22,7 @@ namespace math {
 namespace jitkernel {
 namespace gen {
 
-using namespace platform::jit;  // NOLINT
+using namespace platform;  // NOLINT
 
 bool VXXJitCode::init(int d, int scalar_index) {
   // It's not necessary to use avx512 since it would slow down the frequency
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index a921462129..e2b4761435 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -179,7 +179,7 @@ class VActJitCode : public JitCode {
   template <typename JMM>
   void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12,  // NOLINT
                int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
-    using namespace platform::jit;  // NOLINT
+    using namespace platform;  // NOLINT
     // check all idx can not equal
     JMM jmm_src = JMM(src_idx);
     JMM jmm_fx = JMM(fx_idx);
diff --git a/paddle/fluid/operators/math/jit_gen.cc b/paddle/fluid/operators/math/jit_gen.cc
index 6af39518ed..5c6672928e 100644
--- a/paddle/fluid/operators/math/jit_gen.cc
+++ b/paddle/fluid/operators/math/jit_gen.cc
@@ -36,7 +36,7 @@ void JitCode::preCode() {
   for (int i = 0; i < num_g_abi_regs; ++i) {
     push(Xbyak::Reg64(g_abi_regs[i]));
   }
-  if (platform::jit::MayIUse(platform::jit::avx512f)) {
+  if (platform::MayIUse(platform::avx512f)) {
     mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
   }
 }
diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc
index 68b708b345..118696ba47 100644
--- a/paddle/fluid/operators/math/jit_kernel.cc
+++ b/paddle/fluid/operators/math/jit_kernel.cc
@@ -21,8 +21,6 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
-
 KernelPool& KernelPool::Instance() {
   static thread_local KernelPool g_jit_kernels;
   return g_jit_kernels;
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index a0f93fd8e7..8cf588efba 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -30,7 +30,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
-namespace jit = platform::jit;
 
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
@@ -125,7 +124,7 @@ bool VMulKernelImpl<float>::useJIT(int d) {
 #ifdef PADDLE_WITH_MKLML
 template <>
 bool VMulKernelImpl<float>::useMKL(int d) {
-  return jit::MayIUse(jit::avx512f) && d > 512;
+  return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc
index 4d26b81948..eeb305a88b 100644
--- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc
@@ -25,10 +25,8 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
-
 /* CRF Decode JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T, platform::cpu_isa_t isa, jit_block>
 class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
  public:
   explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel<T>() {
@@ -101,7 +99,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
 
 #define INTRIAVX_FLOAT(block)                                                  \
   template <>                                                                  \
-  CRFDecodeKernelImpl<float, jit::avx, block>::CRFDecodeKernelImpl(            \
+  CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl(       \
       int tag_num)                                                             \
       : CRFDecodeKernel<float>() {                                             \
     this->num_ = tag_num;                                                      \
@@ -109,7 +107,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
     this->rest_ = this->num_ % YMM_FLOAT_BLOCK;                                \
   }                                                                            \
   template <>                                                                  \
-  void CRFDecodeKernelImpl<float, jit::avx, block>::Compute(                   \
+  void CRFDecodeKernelImpl<float, platform::avx, block>::Compute(              \
       const int seq_len, const float* x, const float* w, float* alpha,         \
       int* track) const {                                                      \
     INIT_ALPHA(YMM_FLOAT_BLOCK)                                                \
@@ -204,7 +202,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
 
 #define INTRIAVX512_FLOAT(block)                                               \
   template <>                                                                  \
-  CRFDecodeKernelImpl<float, jit::avx512f, block>::CRFDecodeKernelImpl(        \
+  CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl(   \
       int tag_num)                                                             \
       : CRFDecodeKernel<float>() {                                             \
     this->num_ = tag_num;                                                      \
@@ -212,7 +210,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
     this->rest_ = this->num_ % ZMM_FLOAT_BLOCK;                                \
   }                                                                            \
   template <>                                                                  \
-  void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute(               \
+  void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute(          \
       const int seq_len, const float* x, const float* w, float* alpha,         \
       int* track) const {                                                      \
     INIT_ALPHA(ZMM_FLOAT_BLOCK)                                                \
@@ -270,14 +268,14 @@ INTRIAVX_FLOAT(kEQ16);
 INTRIAVX_FLOAT(kGT16);
 #endif
 #ifdef __AVX2__
-INTRIAVX2_FLOAT(jit::avx2, kEQ8);
-INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
-INTRIAVX2_FLOAT(jit::avx2, kEQ16);
-INTRIAVX2_FLOAT(jit::avx2, kGT16);
+INTRIAVX2_FLOAT(platform::avx2, kEQ8);
+INTRIAVX2_FLOAT(platform::avx2, kGT8LT16);
+INTRIAVX2_FLOAT(platform::avx2, kEQ16);
+INTRIAVX2_FLOAT(platform::avx2, kGT16);
 #endif
 #ifdef __AVX512F__
-INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
-INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
+INTRIAVX2_FLOAT(platform::avx512f, kEQ8);
+INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16);
 INTRIAVX512_FLOAT(kEQ16);
 INTRIAVX512_FLOAT(kGT16);
 #endif
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index 686f3dd983..7945cfb253 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -29,7 +29,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
-namespace jit = platform::jit;
 
 #ifdef PADDLE_WITH_MKLML
 // try to use MKL to speedup
diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
index 49904e6e8c..fead13ebad 100644
--- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
+++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
@@ -22,10 +22,8 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
-
 /* Layer Norm JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T, platform::cpu_isa_t isa, jit_block>
 class LayerNormKernelImpl : public LayerNormKernel<T> {
  public:
   explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
@@ -90,7 +88,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
     this->end_ = this->num_ - this->rest_;                                     \
   }                                                                            \
   template <>                                                                  \
-  void LayerNormKernelImpl<float, jit::avx, block>::Compute(                   \
+  void LayerNormKernelImpl<float, platform::avx, block>::Compute(              \
       float* x, float* out, float* mean, float* var, const float* scale,       \
       const float* bias, int height, const float epsilon) const {              \
     __m256 sum;                                                                \
@@ -219,16 +217,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
   }
 
 #ifdef __AVX__
-INTRIAVX_FLOAT(jit::avx, kEQ8);
-INTRIAVX_FLOAT(jit::avx, kGT8LT16);
-INTRIAVX_FLOAT(jit::avx, kEQ16);
-INTRIAVX_FLOAT(jit::avx, kGT16);
+INTRIAVX_FLOAT(platform::avx, kEQ8);
+INTRIAVX_FLOAT(platform::avx, kGT8LT16);
+INTRIAVX_FLOAT(platform::avx, kEQ16);
+INTRIAVX_FLOAT(platform::avx, kGT16);
 #endif
 #ifdef __AVX2__
-INTRIAVX_FLOAT(jit::avx2, kEQ8);
-INTRIAVX_FLOAT(jit::avx2, kGT8LT16);
-INTRIAVX_FLOAT(jit::avx2, kEQ16);
-INTRIAVX_FLOAT(jit::avx2, kGT16);
+INTRIAVX_FLOAT(platform::avx2, kEQ8);
+INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
+INTRIAVX_FLOAT(platform::avx2, kEQ16);
+INTRIAVX_FLOAT(platform::avx2, kGT16);
 #endif
 
 #undef INTRIAVX_FLOAT
diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h
index 5a3efd979f..4dba3b5681 100644
--- a/paddle/fluid/operators/math/jit_kernel_macro.h
+++ b/paddle/fluid/operators/math/jit_kernel_macro.h
@@ -92,7 +92,6 @@ namespace jitkernel {
                           JITKERNEL_DECLARE, JITKERNEL_FIND_KEY,     \
                           JITKERNEL_IMPL)
 
-namespace jit = platform::jit;
 // TODO(TJ): below defines are deprecated, would be remove recently
 #define SEARCH_BLOCK(macro_, ker, dtype, isa)              \
   if (d < YMM_FLOAT_BLOCK) {                               \
@@ -107,15 +106,15 @@ namespace jit = platform::jit;
     macro_(ker, dtype, isa, kGT16);                        \
   }
 
-#define SEARCH_ISA_BLOCK(macro_, ker, dtype)        \
-  if (jit::MayIUse(jit::avx512f)) {                 \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \
-  } else if (jit::MayIUse(jit::avx2)) {             \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::avx2);    \
-  } else if (jit::MayIUse(jit::avx)) {              \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::avx);     \
-  } else {                                          \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
+#define SEARCH_ISA_BLOCK(macro_, ker, dtype)             \
+  if (platform::MayIUse(platform::avx512f)) {            \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
+  } else if (platform::MayIUse(platform::avx2)) {        \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::avx2);    \
+  } else if (platform::MayIUse(platform::avx)) {         \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::avx);     \
+  } else {                                               \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
   }
 
 #define JITKERNEL_KEY(ker_key, dtype_key) \
@@ -156,10 +155,10 @@ namespace jit = platform::jit;
                                   marco_declare, macro_key, macro_impl)
 
 #define FOR_EACH_ISA(macro_, block) \
-  macro_(jit::avx512f, block);      \
-  macro_(jit::avx2, block);         \
-  macro_(jit::avx, block);          \
-  macro_(jit::isa_any, block)
+  macro_(platform::avx512f, block); \
+  macro_(platform::avx2, block);    \
+  macro_(platform::avx, block);     \
+  macro_(platform::isa_any, block)
 
 #define FOR_EACH_BLOCK(macro_, isa) \
   macro_(isa, kLT8);                \
@@ -168,11 +167,11 @@ namespace jit = platform::jit;
   macro_(isa, kEQ16);               \
   macro_(isa, kGT16)
 
-#define FOR_EACH_ISA_BLOCK(macro_)      \
-  FOR_EACH_BLOCK(macro_, jit::avx512f); \
-  FOR_EACH_BLOCK(macro_, jit::avx2);    \
-  FOR_EACH_BLOCK(macro_, jit::avx);     \
-  FOR_EACH_BLOCK(macro_, jit::isa_any)
+#define FOR_EACH_ISA_BLOCK(macro_)           \
+  FOR_EACH_BLOCK(macro_, platform::avx512f); \
+  FOR_EACH_BLOCK(macro_, platform::avx2);    \
+  FOR_EACH_BLOCK(macro_, platform::avx);     \
+  FOR_EACH_BLOCK(macro_, platform::isa_any)
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index ed86a47e15..19f7bd8909 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -705,7 +705,7 @@ TEST(JitKernel, pool) {
   jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
 
   // dummy call to avoid the unknown flag 'use_pinned_memory' error on Mac
-  paddle::platform::jit::MayIUse(paddle::platform::jit::avx);
+  paddle::platform::MayIUse(paddle::platform::avx);
   const auto& plstm1 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc
index 71b9293eed..5a6e64b6f8 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.cc
+++ b/paddle/fluid/operators/math/matrix_bit_code.cc
@@ -89,6 +89,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                   const framework::Tensor& weight,
                                   const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat->dims()[0];
   size_t tmat_width = tmat->dims()[1];
   size_t input_width = input.dims()[1];
@@ -99,13 +101,12 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_row = input_value + input_width * i;
     for (int j = 0; j < code_length; ++j) {
       size_t index = code->calc_index(j);
+      const T* weight_row = weight_value + weight_width * index;
       T sum = static_cast<T>(0.0);
-      for (size_t k = 0; k < input_width; ++k) {
-        sum += weight_value[weight_width * index + k] *
-               input_value[input_width * i + k];
-      }
+      sum = blas.DOT(input_width, weight_row, input_row);
       tmat_value[i * tmat_width + j] += sum;
     }
   }
@@ -115,6 +116,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::Tensor* weight,
                                             const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -122,16 +125,25 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto tmat_value = tmat.data<T>();
   auto weight_value = weight->data<T>();
   auto input_value = input.data<T>();
+
+  std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
+
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_value_row = input_value + input_width * i;
+    const T* tmat_row = tmat_value + i * tmat_width;
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code->calc_index(j);
-
-      for (size_t k = 0; k < input_width; ++k) {
-        weight_value[weight_width * index + k] +=
-            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
-      }
+      ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
+    }
+  }
+  for (auto& op : ops) {
+    auto& op_in_row = op.second;
+    for (auto& pair : op_in_row) {
+      auto& scale = pair.first;
+      auto* input_row = pair.second;
+      T* weight_row = weight_value + op.first * weight_width;
+      blas.AXPY(input_width, scale, input_row, weight_row);
     }
   }
 }
@@ -140,6 +152,8 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::SelectedRows* weight,
                                             const framework::Tensor& input) {
+  auto blas =
+      GetBlas<platform::CPUDeviceContext, T>(platform::CPUDeviceContext());
   size_t num_samples = tmat.dims()[0];
   size_t input_width = input.dims()[1];
   size_t tmat_width = tmat.dims()[1];
@@ -147,17 +161,28 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
   auto tmat_value = tmat.data<T>();
   auto weight_value = weight->mutable_value()->data<T>();
   auto input_value = input.data<T>();
+
+  std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
+  ops.reserve(weight->rows().size());
+
   for (size_t i = 0; i < num_samples; ++i) {
     auto code = code_table_->get_code(i);
     int code_length = code->get_length();
+    const T* input_value_row = input_value + input_width * i;
+    const T* tmat_row = tmat_value + i * tmat_width;
     for (int j = 0; j < code_length; ++j) {
-      size_t index = code->calc_index(j);
-      for (size_t k = 0; k < input_width; ++k) {
-        int64_t row_index = weight->GetIndexFromId(static_cast<int64_t>(index));
-        weight_value[row_index * weight_width + k] +=
-            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
-      }
+      ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row);
+    }
+  }
+
+  for (auto& row : weight->rows()) {
+    auto& op_in_row = ops[row];
+    for (auto& pair : op_in_row) {
+      auto& scale = pair.first;
+      auto* input_row = pair.second;
+      blas.AXPY(input_width, scale, input_row, weight_value);
     }
+    weight_value += weight_width;
   }
 }
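
The rewrite above replaces the scalar inner loops with BLAS primitives: Mul now computes each row's contribution with a single DOT, and both MulGradWeight overloads first group (scale, input_row) pairs by weight row so that each accumulated update becomes one AXPY. For reference, a plain sketch of what these two calls compute (illustrative only, not the Blas implementation):

template <typename T>
T Dot(int n, const T* x, const T* y) {  // returns sum_i x[i] * y[i]
  T sum = static_cast<T>(0);
  for (int i = 0; i < n; ++i) sum += x[i] * y[i];
  return sum;
}

template <typename T>
void Axpy(int n, T alpha, const T* x, T* y) {  // y[i] += alpha * x[i]
  for (int i = 0; i < n; ++i) y[i] += alpha * x[i];
}

Grouping by row also lets the SelectedRows overload walk weight->rows() once, advancing weight_value row by row, instead of calling GetIndexFromId inside the innermost loop.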
 
diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
index c30bb52641..35ca73802b 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.h
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <unordered_map>
+#include <utility>
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/device_context.h"
 
 #if defined(_WIN32)
diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu
new file mode 100644
index 0000000000..701a802080
--- /dev/null
+++ b/paddle/fluid/operators/math/prelu.cu
@@ -0,0 +1,148 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/prelu.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+static const int CUDA_NUM_THREADS = 1024;
+static const int CUDA_MAX_NUM_BLOCKS = 65535;
+inline static int GET_NUM_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template <typename T>
+__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
+                                       T *output, int channel,
+                                       size_t spatial_size) {
+  size_t offset = blockIdx.x * spatial_size;
+  const T *in = input + offset;
+  T *out = output + offset;
+  T scale = alpha[blockIdx.x % channel];
+
+  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
+    T x = in[i];
+    out[i] = (x > 0) ? x : scale * x;
+  }
+}
+
+template <typename T>
+__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
+                                       T *output, size_t spatial_size) {
+  size_t offset = blockIdx.x * spatial_size;
+  const T *in = input + offset;
+  const T *scale = alpha + offset;
+  T *out = output + offset;
+
+  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
+    T x = in[i];
+    out[i] = (x > 0) ? x : scale[i] * x;
+  }
+}
+
+template <typename T>
+__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
+                                  size_t spatial_size) {
+  size_t offset = blockIdx.x * spatial_size;
+  const T *in = input + offset;
+  T scale = *alpha;
+  T *out = output + offset;
+
+  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
+    T x = in[i];
+    out[i] = (x > 0) ? x : scale * x;
+  }
+}
+
+template <typename T>
+static inline void PReluChannelWise(cudaStream_t stream, const T *input,
+                                    const T *alpha, T *output,
+                                    std::vector<int> input_shape) {
+  size_t unroll = input_shape[0] * input_shape[1];
+  size_t spatial_size = input_shape[2] * input_shape[3];
+  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
+  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, input_shape[1], spatial_size);
+}
+
+template <typename T>
+static inline void PReluElementWise(cudaStream_t stream, const T *input,
+                                    const T *alpha, T *output,
+                                    std::vector<int> input_shape) {
+  size_t unroll = input_shape[0] * input_shape[1];
+  size_t spatial_size = input_shape[2] * input_shape[3];
+  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
+  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, spatial_size);
+}
+
+template <typename T>
+static inline void PReluScalar(cudaStream_t stream, const T *input,
+                               const T *alpha, T *output,
+                               std::vector<int> input_shape) {
+  size_t unroll = input_shape[0] * input_shape[1];
+  size_t spatial_size = input_shape[2] * input_shape[3];
+  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
+  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, spatial_size);
+}
+
+template <typename T>
+void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
+    cudaStream_t stream, const T *input, const T *alpha, T *output,
+    std::vector<int> input_shape) {
+  size_t unroll = input_shape[0] * input_shape[1];
+  size_t spatial_size = input_shape[2] * input_shape[3];
+  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
+  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, input_shape[1], spatial_size);
+}
+
+template <typename T>
+void PreluElementWiseDirectCUDAFunctor<T>::operator()(
+    cudaStream_t stream, const T *input, const T *alpha, T *output,
+    std::vector<int> input_shape) {
+  size_t unroll = input_shape[0] * input_shape[1];
+  size_t spatial_size = input_shape[2] * input_shape[3];
+  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
+  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, spatial_size);
+}
+
+template <typename T>
+void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
+                                                 const T *input, const T *alpha,
+                                                 T *output,
+                                                 std::vector<int> input_shape) {
+  size_t unroll = input_shape[0] * input_shape[1];
+  size_t spatial_size = input_shape[2] * input_shape[3];
+  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
+  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, spatial_size);
+}
+
+template class PreluChannelWiseDirectCUDAFunctor<float>;
+template class PreluChannelWiseDirectCUDAFunctor<double>;
+
+template class PreluElementWiseDirectCUDAFunctor<float>;
+template class PreluElementWiseDirectCUDAFunctor<double>;
+
+template class PreluScalarDirectCUDAFunctor<float>;
+template class PreluScalarDirectCUDAFunctor<double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h
new file mode 100644
index 0000000000..3237c6d4cb
--- /dev/null
+++ b/paddle/fluid/operators/math/prelu.h
@@ -0,0 +1,49 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#ifdef PADDLE_WITH_CUDA
+template <typename T>
+class PreluChannelWiseDirectCUDAFunctor {
+ public:
+  void operator()(cudaStream_t stream, const T *input, const T *alpha,
+                  T *output, std::vector<int> input_shape);
+};
+
+template <typename T>
+class PreluElementWiseDirectCUDAFunctor {
+ public:
+  void operator()(cudaStream_t stream, const T *input, const T *alpha,
+                  T *output, std::vector<int> input_shape);
+};
+
+template <typename T>
+class PreluScalarDirectCUDAFunctor {
+ public:
+  void operator()(cudaStream_t stream, const T *input, const T *alpha,
+                  T *output, std::vector<int> input_shape);
+};
+#endif
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
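
A hypothetical call site for the channel-wise functor declared above; the stream, device pointers, and NCHW extents are assumptions here (the real caller is the CUDA PRelu kernel added later in this patch):

#include <vector>
#include "paddle/fluid/operators/math/prelu.h"

// Sketch: channel-wise PReLU on an NCHW tensor already resident on the GPU.
// d_x and d_out are device buffers of n*c*h*w floats; d_alpha holds one
// scale per channel.
void RunChannelWisePRelu(cudaStream_t stream, const float* d_x,
                         const float* d_alpha, float* d_out, int n, int c,
                         int h, int w) {
  paddle::operators::math::PreluChannelWiseDirectCUDAFunctor<float> prelu;
  prelu(stream, d_x, d_alpha, d_out, std::vector<int>{n, c, h, w});
}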
diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h
index 31ed519666..9e99e44822 100644
--- a/paddle/fluid/operators/math/softmax_impl.h
+++ b/paddle/fluid/operators/math/softmax_impl.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 
diff --git a/paddle/fluid/operators/merge_selected_rows_op.cc b/paddle/fluid/operators/merge_selected_rows_op.cc
new file mode 100644
index 0000000000..3c15c83955
--- /dev/null
+++ b/paddle/fluid/operators/merge_selected_rows_op.cc
@@ -0,0 +1,72 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/merge_selected_rows_op.h"
+
+namespace paddle {
+namespace operators {
+
+class MergeSelectedRowsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of MergeSelectedRowsOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of MergeSelectedRowsOp should not be null.");
+    ctx->ShareDim("X", /*->*/ "Out");
+  }
+};
+
+class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input type is SelectedRows, and the selected rows may be "
+             "duplicated.");
+    AddOutput("Out",
+              "The output type is SelectedRows, and the selected rows are not "
+              "duplicated.");
+    AddComment(
+        R"DOC(
+MergeSelectedRows Operator.
+
+MergeSelectedRows is used to merge the duplicated rows of the input.
+)DOC");
+  }
+};
+
+class MergeSelectedRowsOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OPERATOR(merge_selected_rows, ops::MergeSelectedRowsOp,
+                  ops::MergeSelectedRowsOpMaker,
+                  ops::MergeSelectedRowsOpInferVarType);
+
+REGISTER_OP_CPU_KERNEL(
+    merge_selected_rows,
+    ops::MergeSelectedRowsKernel<plat::CPUDeviceContext, float>,
+    ops::MergeSelectedRowsKernel<plat::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/merge_selected_rows_op.cu.cc b/paddle/fluid/operators/merge_selected_rows_op.cu.cc
new file mode 100644
index 0000000000..90d5fb3eae
--- /dev/null
+++ b/paddle/fluid/operators/merge_selected_rows_op.cu.cc
@@ -0,0 +1,23 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/merge_selected_rows_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(
+    merge_selected_rows,
+    ops::MergeSelectedRowsKernel<plat::CUDADeviceContext, float>,
+    ops::MergeSelectedRowsKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/merge_selected_rows_op.h b/paddle/fluid/operators/merge_selected_rows_op.h
new file mode 100644
index 0000000000..4c977e94b1
--- /dev/null
+++ b/paddle/fluid/operators/merge_selected_rows_op.h
@@ -0,0 +1,36 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class MergeSelectedRowsKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::SelectedRows>("X");
+    auto* out = context.Output<framework::SelectedRows>("Out");
+
+    math::scatter::MergeAdd<DeviceContext, T> merge_func;
+    merge_func(context.template device_context<DeviceContext>(), *x, out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
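
The kernel above delegates all the work to math::scatter::MergeAdd. A minimal host-side sketch of the merge semantics, under the simplifying assumption of plain std:: containers (the real functor operates on SelectedRows and device memory, and its output row order may differ):

#include <cstdint>
#include <map>
#include <vector>

// Rows with the same id are summed into one row; e.g. rows {4, 0, 4} with
// values [[1,1],[2,2],[3,3]] merge to rows {0, 4} with [[2,2],[4,4]].
std::map<int64_t, std::vector<float>> MergeRows(
    const std::vector<int64_t>& rows,
    const std::vector<std::vector<float>>& values) {
  std::map<int64_t, std::vector<float>> merged;
  for (size_t i = 0; i < rows.size(); ++i) {
    auto& dst = merged[rows[i]];
    dst.resize(values[i].size(), 0.0f);
    for (size_t j = 0; j < values[i].size(); ++j) dst[j] += values[i][j];
  }
  return merged;
}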
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index 58cfbb76e9..64d94ab604 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -58,7 +58,7 @@ class PReluOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        platform::CPUPlace());
+        ctx.device_context());
   }
 };
 
diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu
new file mode 100644
index 0000000000..36b5259ae5
--- /dev/null
+++ b/paddle/fluid/operators/prelu_op.cu
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/prelu.h"
+#include "paddle/fluid/operators/prelu_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class CUDAPReluKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto* out = context.Output<Tensor>("Out");
+
+    const T* x_ptr = x->data<T>();
+    T* o_ptr = out->mutable_data<T>(context.GetPlace());
+
+    const T* alpha_ptr = alpha->data<T>();
+    auto& mode = context.Attr<std::string>("mode");
+
+    int numel = x->numel();
+    auto dim = x->dims();
+    std::vector<int> input_shape = framework::vectorize2int(dim);
+
+    if (mode == "channel") {
+      math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
+      prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
+                         alpha_ptr, o_ptr, input_shape);
+    } else if (mode == "element") {
+      math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
+      prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
+                         alpha_ptr, o_ptr, input_shape);
+    } else {
+      math::PreluScalarDirectCUDAFunctor<T> prelu_scalar;
+      prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr,
+                   o_ptr, input_shape);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    prelu, ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h
index 18acb735ce..8fceed3558 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h
@@ -36,12 +36,10 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
     PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
 
-    auto maxlen = ctx->Attrs().Get<int>("maxlen");
-    if (maxlen > 0) {  // We can only infershape when maxlen > 0
-      auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
-      dim.push_back(maxlen);
-      ctx->SetOutputDim("Y", framework::make_ddim(dim));
-    }
+    int maxlen = ctx->Attrs().Get<int>("maxlen");
+    auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
+    dim.push_back(maxlen > 0 ? maxlen : -1);
+    ctx->SetOutputDim("Y", framework::make_ddim(dim));
   }
 };
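
With the change above, SequenceMask's output shape is always set at compile time: for X of shape [d1, ..., dk] and maxlen = 10, Y is inferred as [d1, ..., dk, 10]; with maxlen <= 0 the last dimension is now recorded as -1 rather than leaving Y's shape unset, and its actual extent is resolved at run time by the kernel (from the maximum value in X).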
 
diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc
new file mode 100644
index 0000000000..e7597f7324
--- /dev/null
+++ b/paddle/fluid/operators/yolov3_loss_op.cc
@@ -0,0 +1,221 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/fluid/operators/yolov3_loss_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class Yolov3LossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of Yolov3LossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("GTBox"),
+                   "Input(GTBox) of Yolov3LossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("GTLabel"),
+                   "Input(GTLabel) of Yolov3LossOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
+                   "Output(Loss) of Yolov3LossOp should not be null.");
+
+    auto dim_x = ctx->GetInputDim("X");
+    auto dim_gtbox = ctx->GetInputDim("GTBox");
+    auto dim_gtlabel = ctx->GetInputDim("GTLabel");
+    auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
+    auto class_num = ctx->Attrs().Get<int>("class_num");
+    PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
+    PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
+                      "Input(X) dim[2] and dim[3] should be equal.");
+    PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
+                      "Input(X) dim[1] should be equal to (anchor_number * (5 "
+                      "+ class_num)).");
+    PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
+                      "Input(GTBox) should be a 3-D tensor");
+    PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 4");
+    PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2,
+                      "Input(GTLabel) should be a 2-D tensor");
+    PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0],
+                      "Input(GTBox) and Input(GTLabel) dim[0] should be equal");
+    PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
+                      "Input(GTBox) and Input(GTLabel) dim[1] should be equal");
+    PADDLE_ENFORCE_GT(anchors.size(), 0,
+                      "Attr(anchors) length should be greater than 0.");
+    PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
+                      "Attr(anchors) length should be an even integer.");
+    PADDLE_ENFORCE_GT(class_num, 0,
+                      "Attr(class_num) should be an integer greater than 0.");
+
+    std::vector<int64_t> dim_out({1});
+    ctx->SetOutputDim("Loss", framework::make_ddim(dim_out));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input tensor of YOLO v3 loss operator, "
+             "This is a 4-D tensor with shape of [N, C, H, W]."
+             "H and W should be same, and the second dimention(C) stores"
+             "box locations, confidence score and classification one-hot"
+             "key of each anchor box");
+    AddInput("GTBox",
+             "The input tensor of ground truth boxes, "
+             "This is a 3-D tensor with shape of [N, max_box_num, 5], "
+             "max_box_num is the max number of boxes in each image, "
+             "In the third dimention, stores x, y, w, h coordinates, "
+             "x, y is the center cordinate of boxes and w, h is the "
+             "width and height and x, y, w, h should be divided by "
+             "input image height to scale to [0, 1].");
+    AddInput("GTLabel",
+             "The input tensor of ground truth label, "
+             "This is a 2-D tensor with shape of [N, max_box_num], "
+             "and each element shoudl be an integer to indicate the "
+             "box class id.");
+    AddOutput("Loss",
+              "The output yolov3 loss tensor, "
+              "This is a 1-D tensor with shape of [1]");
+
+    AddAttr<int>("class_num", "The number of classes to predict.");
+    AddAttr<std::vector<int>>("anchors",
+                              "The anchor width and height, "
+                              "it will be parsed pair by pair.");
+    AddAttr<float>("ignore_thresh",
+                   "The ignore threshold to ignore confidence loss.");
+    AddAttr<float>("loss_weight_xy", "The weight of x, y location loss.")
+        .SetDefault(1.0);
+    AddAttr<float>("loss_weight_wh", "The weight of w, h location loss.")
+        .SetDefault(1.0);
+    AddAttr<float>(
+        "loss_weight_conf_target",
+        "The weight of confidence score loss in locations with target object.")
+        .SetDefault(1.0);
+    AddAttr<float>("loss_weight_conf_notarget",
+                   "The weight of confidence score loss in locations without "
+                   "target object.")
+        .SetDefault(1.0);
+    AddAttr<float>("loss_weight_class", "The weight of classification loss.")
+        .SetDefault(1.0);
+    AddComment(R"DOC(
+         This operator generate yolov3 loss by given predict result and ground
+         truth boxes.
+         
+         The output of previous network is in shape [N, C, H, W], while H and W
+         should be the same, specify the grid size, each grid point predict given
+         number boxes, this given number is specified by anchors, it should be 
+         half anchors length, which following will be represented as S. In the 
+         second dimention(the channel dimention), C should be S * (class_num + 5),
+         class_num is the box categoriy number of source dataset(such as coco), 
+         so in the second dimention, stores 4 box location coordinates x, y, w, h 
+         and confidence score of the box and class one-hot key of each anchor box.
+
+         While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions
+         correspnd to:
+
+         $$
+         b_x = \sigma(t_x) + c_x
+         b_y = \sigma(t_y) + c_y
+         b_w = p_w e^{t_w}
+         b_h = p_h e^{t_h}
+         $$
+
+         While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$
+         is specified by anchors.
+
+         As for confidence score, it is the logistic regression value of IoU between
+         anchor boxes and ground truth boxes, the score of the anchor box which has 
+         the max IoU should be 1, and if the anchor box has IoU bigger then ignore 
+         thresh, the confidence score loss of this anchor box will be ignored.
+
+         Therefore, the yolov3 loss consist of three major parts, box location loss,
+         confidence score loss, and classification loss. The MSE loss is used for 
+         box location, and binary cross entropy loss is used for confidence score 
+         loss and classification loss.
+
+         Final loss will be represented as follow.
+
+         $$
+         loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh}
+              + \loss_weight_{conf_target} * loss_{conf_target}
+              + \loss_weight_{conf_notarget} * loss_{conf_notarget}
+              + \loss_weight_{class} * loss_{class}
+         $$
+         )DOC");
+  }
+};
+
+class Yolov3LossOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
+                   "Input(Loss@GRAD) should not be null");
+    auto dim_x = ctx->GetInputDim("X");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("yolov3_loss_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("GTBox", Input("GTBox"));
+    op->SetInput("GTLabel", Input("GTLabel"));
+    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("GTBox"), {});
+    op->SetOutput(framework::GradVarName("GTLabel"), {});
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
+                  ops::Yolov3LossGradMaker);
+REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
+REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
+                       ops::Yolov3LossKernel<double>);
+REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
+                       ops::Yolov3LossGradKernel<double>);
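
The box decoding described in the operator comment, as a self-contained sketch; cx, cy are the grid cell corner and pw, ph the anchor prior. This mirrors the $$b_x, b_y, b_w, b_h$$ formulas above, not the actual kernel code in yolov3_loss_op.h:

#include <cmath>

struct Box {
  float x, y, w, h;
};

// Decode one predicted box from the raw network outputs (tx, ty, tw, th),
// the grid cell corner (cx, cy) and the anchor prior (pw, ph).
Box DecodeBox(float tx, float ty, float tw, float th, float cx, float cy,
              float pw, float ph) {
  auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
  return {sigmoid(tx) + cx, sigmoid(ty) + cy, pw * std::exp(tw),
          ph * std::exp(th)};
}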
diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h
new file mode 100644
index 0000000000..0bb285722d
--- /dev/null
+++ b/paddle/fluid/operators/yolov3_loss_op.h
@@ -0,0 +1,483 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+using Array5 = Eigen::DSizes<int64_t, 5>;
+
+template <typename T>
+static inline bool isZero(T x) {
+  return fabs(x) < 1e-6;
+}
+
+template <typename T>
+static inline T sigmoid(T x) {
+  return 1.0 / (exp(-1.0 * x) + 1.0);
+}
+
+template <typename T>
+static inline T CalcMaskPointNum(const Tensor& mask) {
+  auto mask_t = EigenVector<int>::Flatten(mask);
+  T count = 0.0;
+  for (int i = 0; i < mask_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      count += 1.0;
+    }
+  }
+  return count;
+}
+
+template <typename T>
+static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y,
+                                const Tensor& mask) {
+  auto x_t = EigenVector<T>::Flatten(x);
+  auto y_t = EigenVector<T>::Flatten(y);
+  auto mask_t = EigenVector<int>::Flatten(mask);
+
+  T error_sum = 0.0;
+  T points = 0.0;
+  for (int i = 0; i < x_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      error_sum += pow(x_t(i) - y_t(i), 2);
+      points += 1;
+    }
+  }
+  return (error_sum / points);
+}
+
+template <typename T>
+static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y,
+                                const Tensor& mask, T mf) {
+  auto grad_t = EigenVector<T>::Flatten(*grad).setConstant(0.0);
+  auto x_t = EigenVector<T>::Flatten(x);
+  auto y_t = EigenVector<T>::Flatten(y);
+  auto mask_t = EigenVector<int>::Flatten(mask);
+
+  for (int i = 0; i < x_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf;
+    }
+  }
+}
+
+template <typename T>
+static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y,
+                                const Tensor& mask) {
+  auto x_t = EigenVector<T>::Flatten(x);
+  auto y_t = EigenVector<T>::Flatten(y);
+  auto mask_t = EigenVector<int>::Flatten(mask);
+
+  T error_sum = 0.0;
+  T points = 0.0;
+  for (int i = 0; i < x_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      error_sum +=
+          -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i)));
+      points += 1;
+    }
+  }
+  return (error_sum / points);
+}
+
+template <typename T>
+static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x,
+                                       const Tensor& y, const Tensor& mask,
+                                       T mf) {
+  auto grad_t = EigenVector<T>::Flatten(*grad).setConstant(0.0);
+  auto x_t = EigenVector<T>::Flatten(x);
+  auto y_t = EigenVector<T>::Flatten(y);
+  auto mask_t = EigenVector<int>::Flatten(mask);
+
+  for (int i = 0; i < x_t.dimensions()[0]; i++) {
+    if (mask_t(i)) {
+      grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf;
+    }
+  }
+}
+
+template <typename T>
+static void CalcPredResult(const Tensor& input, Tensor* pred_conf,
+                           Tensor* pred_class, Tensor* pred_x, Tensor* pred_y,
+                           Tensor* pred_w, Tensor* pred_h, const int anchor_num,
+                           const int class_num) {
+  const int n = input.dims()[0];
+  const int h = input.dims()[2];
+  const int w = input.dims()[3];
+  const int box_attr_num = 5 + class_num;
+
+  auto input_t = EigenTensor<T, 4>::From(input);
+  auto pred_conf_t = EigenTensor<T, 4>::From(*pred_conf);
+  auto pred_class_t = EigenTensor<T, 5>::From(*pred_class);
+  auto pred_x_t = EigenTensor<T, 4>::From(*pred_x);
+  auto pred_y_t = EigenTensor<T, 4>::From(*pred_y);
+  auto pred_w_t = EigenTensor<T, 4>::From(*pred_w);
+  auto pred_h_t = EigenTensor<T, 4>::From(*pred_h);
+
+  for (int i = 0; i < n; i++) {
+    for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
+      for (int j = 0; j < h; j++) {
+        for (int k = 0; k < w; k++) {
+          pred_x_t(i, an_idx, j, k) =
+              sigmoid(input_t(i, box_attr_num * an_idx, j, k));
+          pred_y_t(i, an_idx, j, k) =
+              sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k));
+          pred_w_t(i, an_idx, j, k) =
+              input_t(i, box_attr_num * an_idx + 2, j, k);
+          pred_h_t(i, an_idx, j, k) =
+              input_t(i, box_attr_num * an_idx + 3, j, k);
+
+          pred_conf_t(i, an_idx, j, k) =
+              sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k));
+
+          for (int c = 0; c < class_num; c++) {
+            pred_class_t(i, an_idx, j, k, c) =
+                sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k));
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
+  T b1_x1 = box1[0] - box1[2] / 2;
+  T b1_x2 = box1[0] + box1[2] / 2;
+  T b1_y1 = box1[1] - box1[3] / 2;
+  T b1_y2 = box1[1] + box1[3] / 2;
+  T b2_x1 = box2[0] - box2[2] / 2;
+  T b2_x2 = box2[0] + box2[2] / 2;
+  T b2_y1 = box2[1] - box2[3] / 2;
+  T b2_y2 = box2[1] + box2[3] / 2;
+
+  T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1);
+  T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1);
+
+  T inter_rect_x1 = std::max(b1_x1, b2_x1);
+  T inter_rect_y1 = std::max(b1_y1, b2_y1);
+  T inter_rect_x2 = std::min(b1_x2, b2_x2);
+  T inter_rect_y2 = std::min(b1_y2, b2_y2);
+  T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast<T>(0.0)) *
+                 std::max(inter_rect_y2 - inter_rect_y1, static_cast<T>(0.0));
+
+  return inter_area / (b1_area + b2_area - inter_area);
+}
+
+template <typename T>
+static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
+                            const float ignore_thresh, std::vector<int> anchors,
+                            const int grid_size, Tensor* obj_mask,
+                            Tensor* noobj_mask, Tensor* tx, Tensor* ty,
+                            Tensor* tw, Tensor* th, Tensor* tconf,
+                            Tensor* tclass) {
+  const int n = gt_box.dims()[0];
+  const int b = gt_box.dims()[1];
+  const int anchor_num = anchors.size() / 2;
+  auto gt_box_t = EigenTensor<T, 3>::From(gt_box);
+  auto gt_label_t = EigenTensor<int, 2>::From(gt_label);
+  auto obj_mask_t = EigenTensor<int, 4>::From(*obj_mask).setConstant(0);
+  auto noobj_mask_t = EigenTensor<int, 4>::From(*noobj_mask).setConstant(1);
+  auto tx_t = EigenTensor<T, 4>::From(*tx).setConstant(0.0);
+  auto ty_t = EigenTensor<T, 4>::From(*ty).setConstant(0.0);
+  auto tw_t = EigenTensor<T, 4>::From(*tw).setConstant(0.0);
+  auto th_t = EigenTensor<T, 4>::From(*th).setConstant(0.0);
+  auto tconf_t = EigenTensor<T, 4>::From(*tconf).setConstant(0.0);
+  auto tclass_t = EigenTensor<T, 5>::From(*tclass).setConstant(0.0);
+
+  for (int i = 0; i < n; i++) {
+    for (int j = 0; j < b; j++) {
+      if (isZero<T>(gt_box_t(i, j, 0)) && isZero<T>(gt_box_t(i, j, 1)) &&
+          isZero<T>(gt_box_t(i, j, 2)) && isZero<T>(gt_box_t(i, j, 3))) {
+        continue;
+      }
+
+      int cur_label = gt_label_t(i, j);
+      T gx = gt_box_t(i, j, 0) * grid_size;
+      T gy = gt_box_t(i, j, 1) * grid_size;
+      T gw = gt_box_t(i, j, 2) * grid_size;
+      T gh = gt_box_t(i, j, 3) * grid_size;
+      int gi = static_cast<int>(gx);
+      int gj = static_cast<int>(gy);
+
+      T max_iou = static_cast<T>(0);
+      T iou;
+      int best_an_index = -1;
+      std::vector<T> gt_box_shape({0, 0, gw, gh});
+      for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
+        std::vector<T> anchor_shape({0, 0, static_cast<T>(anchors[2 * an_idx]),
+                                     static_cast<T>(anchors[2 * an_idx + 1])});
+        iou = CalcBoxIoU<T>(gt_box_shape, anchor_shape);
+        if (iou > max_iou) {
+          max_iou = iou;
+          best_an_index = an_idx;
+        }
+        if (iou > ignore_thresh) {
+          noobj_mask_t(i, an_idx, gj, gi) = 0;
+        }
+      }
+      obj_mask_t(i, best_an_index, gj, gi) = 1;
+      noobj_mask_t(i, best_an_index, gj, gi) = 0;
+      tx_t(i, best_an_index, gj, gi) = gx - gi;
+      ty_t(i, best_an_index, gj, gi) = gy - gj;
+      tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]);
+      th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]);
+      tclass_t(i, best_an_index, gj, gi, cur_label) = 1;
+      tconf_t(i, best_an_index, gj, gi) = 1;
+    }
+  }
+}
+
+static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand,
+                                    const Tensor& obj_mask) {
+  const int n = obj_mask_expand->dims()[0];
+  const int an_num = obj_mask_expand->dims()[1];
+  const int h = obj_mask_expand->dims()[2];
+  const int w = obj_mask_expand->dims()[3];
+  const int class_num = obj_mask_expand->dims()[4];
+  auto obj_mask_expand_t = EigenTensor<int, 5>::From(*obj_mask_expand);
+  auto obj_mask_t = EigenTensor<int, 4>::From(obj_mask);
+
+  obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1))
+                          .broadcast(Array5(1, 1, 1, 1, class_num));
+}
+
+template <typename T>
+static void AddAllGradToInputGrad(
+    Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y,
+    const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x,
+    const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h,
+    const Tensor& grad_conf_target, const Tensor& grad_conf_notarget,
+    const Tensor& grad_class, const int class_num, const float loss_weight_xy,
+    const float loss_weight_wh, const float loss_weight_conf_target,
+    const float loss_weight_conf_notarget, const float loss_weight_class) {
+  const int n = pred_x.dims()[0];
+  const int an_num = pred_x.dims()[1];
+  const int h = pred_x.dims()[2];
+  const int w = pred_x.dims()[3];
+  const int attr_num = class_num + 5;
+  auto grad_t = EigenTensor<T, 4>::From(*grad).setConstant(0.0);
+  auto pred_x_t = EigenTensor<T, 4>::From(pred_x);
+  auto pred_y_t = EigenTensor<T, 4>::From(pred_y);
+  auto pred_conf_t = EigenTensor<T, 4>::From(pred_conf);
+  auto pred_class_t = EigenTensor<T, 5>::From(pred_class);
+  auto grad_x_t = EigenTensor<T, 4>::From(grad_x);
+  auto grad_y_t = EigenTensor<T, 4>::From(grad_y);
+  auto grad_w_t = EigenTensor<T, 4>::From(grad_w);
+  auto grad_h_t = EigenTensor<T, 4>::From(grad_h);
+  auto grad_conf_target_t = EigenTensor<T, 4>::From(grad_conf_target);
+  auto grad_conf_notarget_t = EigenTensor<T, 4>::From(grad_conf_notarget);
+  auto grad_class_t = EigenTensor<T, 5>::From(grad_class);
+
+  for (int i = 0; i < n; i++) {
+    for (int j = 0; j < an_num; j++) {
+      for (int k = 0; k < h; k++) {
+        for (int l = 0; l < w; l++) {
+          grad_t(i, j * attr_num, k, l) =
+              grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) *
+              (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy;
+          grad_t(i, j * attr_num + 1, k, l) =
+              grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) *
+              (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy;
+          grad_t(i, j * attr_num + 2, k, l) =
+              grad_w_t(i, j, k, l) * loss * loss_weight_wh;
+          grad_t(i, j * attr_num + 3, k, l) =
+              grad_h_t(i, j, k, l) * loss * loss_weight_wh;
+          grad_t(i, j * attr_num + 4, k, l) =
+              grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
+              (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target;
+          grad_t(i, j * attr_num + 4, k, l) +=
+              grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
+              (1.0 - pred_conf_t(i, j, k, l)) * loss *
+              loss_weight_conf_notarget;
+
+          for (int c = 0; c < class_num; c++) {
+            grad_t(i, j * attr_num + 5 + c, k, l) =
+                grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) *
+                (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class;
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+class Yolov3LossKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* gt_box = ctx.Input<Tensor>("GTBox");
+    auto* gt_label = ctx.Input<Tensor>("GTLabel");
+    auto* loss = ctx.Output<Tensor>("Loss");
+    auto anchors = ctx.Attr<std::vector<int>>("anchors");
+    int class_num = ctx.Attr<int>("class_num");
+    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
+    float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
+    float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
+    float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
+    float loss_weight_conf_notarget =
+        ctx.Attr<float>("loss_weight_conf_notarget");
+    float loss_weight_class = ctx.Attr<float>("loss_weight_class");
+
+    const int n = input->dims()[0];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int an_num = anchors.size() / 2;
+
+    Tensor pred_x, pred_y, pred_w, pred_h;
+    Tensor pred_conf, pred_class;
+    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
+                      &pred_w, &pred_h, an_num, class_num);
+
+    Tensor obj_mask, noobj_mask;
+    Tensor tx, ty, tw, th, tconf, tclass;
+    obj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
+    noobj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
+    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask,
+                       &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
+
+    Tensor obj_mask_expand;
+    obj_mask_expand.mutable_data<int>({n, an_num, h, w, class_num},
+                                      ctx.GetPlace());
+    ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask);
+
+    T loss_x = CalcMSEWithMask<T>(pred_x, tx, obj_mask);
+    T loss_y = CalcMSEWithMask<T>(pred_y, ty, obj_mask);
+    T loss_w = CalcMSEWithMask<T>(pred_w, tw, obj_mask);
+    T loss_h = CalcMSEWithMask<T>(pred_h, th, obj_mask);
+    T loss_conf_target = CalcBCEWithMask<T>(pred_conf, tconf, obj_mask);
+    T loss_conf_notarget = CalcBCEWithMask<T>(pred_conf, tconf, noobj_mask);
+    T loss_class = CalcBCEWithMask<T>(pred_class, tclass, obj_mask_expand);
+
+    auto* loss_data = loss->mutable_data<T>({1}, ctx.GetPlace());
+    loss_data[0] = loss_weight_xy * (loss_x + loss_y) +
+                   loss_weight_wh * (loss_w + loss_h) +
+                   loss_weight_conf_target * loss_conf_target +
+                   loss_weight_conf_notarget * loss_conf_notarget +
+                   loss_weight_class * loss_class;
+  }
+};
+
+template <typename T>
+class Yolov3LossGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* gt_box = ctx.Input<Tensor>("GTBox");
+    auto* gt_label = ctx.Input<Tensor>("GTLabel");
+    auto anchors = ctx.Attr<std::vector<int>>("anchors");
+    int class_num = ctx.Attr<int>("class_num");
+    float ignore_thresh = ctx.Attr<float>("ignore_thresh");
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    const T loss = output_grad->data<T>()[0];
+    float loss_weight_xy = ctx.Attr<float>("loss_weight_xy");
+    float loss_weight_wh = ctx.Attr<float>("loss_weight_wh");
+    float loss_weight_conf_target = ctx.Attr<float>("loss_weight_conf_target");
+    float loss_weight_conf_notarget =
+        ctx.Attr<float>("loss_weight_conf_notarget");
+    float loss_weight_class = ctx.Attr<float>("loss_weight_class");
+
+    const int n = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int an_num = anchors.size() / 2;
+
+    Tensor pred_x, pred_y, pred_w, pred_h;
+    Tensor pred_conf, pred_class;
+    pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
+                      &pred_w, &pred_h, an_num, class_num);
+
+    Tensor obj_mask, noobj_mask;
+    Tensor tx, ty, tw, th, tconf, tclass;
+    obj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
+    noobj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
+    tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask,
+                       &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
+
+    Tensor obj_mask_expand;
+    obj_mask_expand.mutable_data<int>({n, an_num, h, w, class_num},
+                                      ctx.GetPlace());
+    ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask);
+
+    Tensor grad_x, grad_y, grad_w, grad_h;
+    Tensor grad_conf_target, grad_conf_notarget, grad_class;
+    grad_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_conf_target.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_conf_notarget.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
+    grad_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
+    T obj_mf = CalcMaskPointNum<int>(obj_mask);
+    T noobj_mf = CalcMaskPointNum<int>(noobj_mask);
+    T obj_expand_mf = CalcMaskPointNum<int>(obj_mask_expand);
+    CalcMSEGradWithMask<T>(&grad_x, pred_x, tx, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_y, pred_y, ty, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_w, pred_w, tw, obj_mask, obj_mf);
+    CalcMSEGradWithMask<T>(&grad_h, pred_h, th, obj_mask, obj_mf);
+    CalcBCEGradWithMask<T>(&grad_conf_target, pred_conf, tconf, obj_mask,
+                           obj_mf);
+    CalcBCEGradWithMask<T>(&grad_conf_notarget, pred_conf, tconf, noobj_mask,
+                           noobj_mf);
+    CalcBCEGradWithMask<T>(&grad_class, pred_class, tclass, obj_mask_expand,
+                           obj_expand_mf);
+
+    input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
+    AddAllGradToInputGrad<T>(
+        input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y,
+        grad_w, grad_h, grad_conf_target, grad_conf_notarget, grad_class,
+        class_num, loss_weight_xy, loss_weight_wh, loss_weight_conf_target,
+        loss_weight_conf_notarget, loss_weight_class);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc
index d466f28d1e..f9a32bfa4c 100644
--- a/paddle/fluid/platform/cpu_info.cc
+++ b/paddle/fluid/platform/cpu_info.cc
@@ -123,7 +123,6 @@ size_t CUDAPinnedMaxChunkSize() {
   return CUDAPinnedMaxAllocSize() / 256;
 }
 
-namespace jit {
 #ifdef PADDLE_WITH_XBYAK
 static Xbyak::util::Cpu cpu;
 bool MayIUse(const cpu_isa_t cpu_isa) {
@@ -165,6 +164,5 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
 }
 #endif
 
-}  // namespace jit
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h
index fd31ef77b4..55dba545ff 100644
--- a/paddle/fluid/platform/cpu_info.h
+++ b/paddle/fluid/platform/cpu_info.h
@@ -39,7 +39,6 @@ size_t CUDAPinnedMinChunkSize();
 //! Get the maximum chunk size for buddy allocator.
 size_t CUDAPinnedMaxChunkSize();
 
-namespace jit {
 typedef enum {
   isa_any,
   sse42,
@@ -55,7 +54,5 @@ typedef enum {
 // May I use some instruction
 bool MayIUse(const cpu_isa_t cpu_isa);
 
-}  // namespace jit
-
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index d0a108f905..146a205832 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -120,15 +120,24 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
   }
 
   void* allocate(size_t num_bytes) const override {
+    if (UNLIKELY(num_bytes == 0)) {
+      return nullptr;
+    }
     auto buf = paddle::memory::Alloc(place_, num_bytes,
                                      memory::Allocator::kScratchpad);
     void* retv = buf->ptr();
-    allocations_[buf->ptr()] = std::move(buf);
+    {
+      std::lock_guard<std::mutex> lock(mtx_);
+      allocations_.emplace(retv, std::move(buf));
+    }
     return retv;
   }
 
   void deallocate(void* buffer) const override {
-    allocations_.erase(allocations_.find(buffer));
+    if (LIKELY(buffer)) {
+      std::lock_guard<std::mutex> lock(mtx_);
+      allocations_.erase(buffer);
+    }
   }
 
   void* scratchpad() const override {
@@ -155,6 +164,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
   const cudaDeviceProp* device_prop_;  // not owned;
   mutable void* scratch_;
   mutable unsigned int* semaphore_;
+  mutable std::mutex mtx_;  // to protect allocations_
   mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
 };
 
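The allocate/deallocate change above is a guarded-map pattern: zero-byte requests short-circuit to nullptr, and the pointer-to-allocation map is protected by a mutex because Eigen may call these hooks from several threads. A minimal Python sketch of the same pattern (illustrative only, not Paddle API):

    import threading

    class ScratchAllocator:
        # keeps allocations alive in a lock-guarded map, keyed by handle
        def __init__(self):
            self._lock = threading.Lock()
            self._allocations = {}

        def allocate(self, num_bytes):
            if num_bytes == 0:              # mirrors UNLIKELY(num_bytes == 0)
                return None
            buf = bytearray(num_bytes)      # stands in for memory::Alloc
            handle = id(buf)
            with self._lock:                # mirrors std::lock_guard<std::mutex>
                self._allocations[handle] = buf
            return handle

        def deallocate(self, handle):
            if handle is None:              # mirrors LIKELY(buffer)
                return
            with self._lock:
                self._allocations.pop(handle, None)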
diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc
index dc1d751141..0a4563ead6 100644
--- a/paddle/fluid/platform/device_tracer.cc
+++ b/paddle/fluid/platform/device_tracer.cc
@@ -143,7 +143,7 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
           case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
             auto *kernel =
                 reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
-            tracer->AddKernelRecords(kernel->start, kernel->end,
+            tracer->AddKernelRecords(kernel->name, kernel->start, kernel->end,
                                      kernel->deviceId, kernel->streamId,
                                      kernel->correlationId);
             break;
@@ -224,8 +224,9 @@ class DeviceTracerImpl : public DeviceTracer {
                                      stream_id, correlation_id, bytes});
   }
 
-  void AddKernelRecords(uint64_t start, uint64_t end, int64_t device_id,
-                        int64_t stream_id, uint32_t correlation_id) {
+  void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
+                        int64_t device_id, int64_t stream_id,
+                        uint32_t correlation_id) {
     // 0 means timestamp information could not be collected for the kernel.
     if (start == 0 || end == 0) {
       VLOG(3) << correlation_id << " cannot be traced";
@@ -233,7 +234,7 @@ class DeviceTracerImpl : public DeviceTracer {
     }
     std::lock_guard<std::mutex> l(trace_mu_);
     kernel_records_.push_back(
-        KernelRecord{start, end, device_id, stream_id, correlation_id});
+        KernelRecord{name, start, end, device_id, stream_id, correlation_id});
   }
 
   bool IsEnabled() {
@@ -276,13 +277,13 @@ class DeviceTracerImpl : public DeviceTracer {
     profile_pb.set_start_ns(start_ns_);
     profile_pb.set_end_ns(end_ns_);
     for (const KernelRecord &r : kernel_records_) {
-      if (correlations_.find(r.correlation_id) == correlations_.end()) {
-        fprintf(stderr, "cannot relate a kernel activity\n");
-        continue;
-      }
       auto *event = profile_pb.add_events();
       event->set_type(proto::Event::GPUKernel);
-      event->set_name(correlations_.at(r.correlation_id));
+      if (correlations_.find(r.correlation_id) != correlations_.end()) {
+        event->set_name(correlations_.at(r.correlation_id));
+      } else {
+        event->set_name(r.name);
+      }
       event->set_start_ns(r.start_ns);
       event->set_end_ns(r.end_ns);
       event->set_sub_device_id(r.stream_id);
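With this change, kernel records lacking an annotation are no longer dropped from the profile; the raw kernel name captured from CUPTI is used as a fallback. The lookup order as a runnable Python sketch (a dict stands in for the correlations_ map):

    def event_name(correlations, record):
        # prefer the human-readable annotation for this correlation id,
        # otherwise fall back to the kernel's own name
        return correlations.get(record["correlation_id"], record["name"])

    assert event_name({1: "conv2d"}, {"correlation_id": 1, "name": "sgemm"}) == "conv2d"
    assert event_name({}, {"correlation_id": 2, "name": "sgemm"}) == "sgemm"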
diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h
index eaf047d474..bf0786be2d 100644
--- a/paddle/fluid/platform/device_tracer.h
+++ b/paddle/fluid/platform/device_tracer.h
@@ -39,6 +39,7 @@ inline uint64_t PosixInNsec() {
 class DeviceTracer {
  public:
   struct KernelRecord {
+    std::string name;
     uint64_t start_ns;
     uint64_t end_ns;
     int64_t device_id;
@@ -84,8 +85,9 @@ class DeviceTracer {
 
   // Add a cuda kernel stats. `correlation_id` will be mapped to annotation
   // added before for human readability.
-  virtual void AddKernelRecords(uint64_t start, uint64_t end, int64_t device_id,
-                                int64_t stream_id, uint32_t correlation_id) = 0;
+  virtual void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
+                                int64_t device_id, int64_t stream_id,
+                                uint32_t correlation_id) = 0;
 
   // Generate a proto after done (Disabled).
   virtual proto::Profile GenProfile(const std::string& profile_path) = 0;
diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h
index 213cd8a9ce..550fe2edee 100644
--- a/paddle/fluid/platform/dynload/cudnn.h
+++ b/paddle/fluid/platform/dynload/cudnn.h
@@ -125,8 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(cudnnRNNBackwardWeights);                       \
   __macro(cudnnRNNForwardInference);                      \
   __macro(cudnnDestroyDropoutDescriptor);                 \
-  __macro(cudnnDestroyRNNDescriptor);                     \
-  __macro(cudnnSetRNNDescriptor_v6);
+  __macro(cudnnDestroyRNNDescriptor);
 
 CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 
@@ -165,6 +164,12 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 #endif
 
+// APIs in R6
+#if CUDNN_VERSION >= 6000
+#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) __macro(cudnnSetRNNDescriptor_v6);
+CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
+#endif
+
 #if CUDNN_VERSION >= 7001
 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro)        \
   __macro(cudnnSetConvolutionGroupCount);         \
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index 6954e4c6a9..ca89d91aad 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 
 #include "gflags/gflags.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/string/split.h"
 
 #ifndef _WIN32
 constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
@@ -45,6 +46,15 @@ DEFINE_bool(
     "input and output must be half precision) and recurrent neural networks "
     "(RNNs).");
 
+DEFINE_string(selected_gpus, "",
+              "A list of device ids separated by comma, like: 0,1,2,3. "
+              "This option is useful when doing multi process training and "
+              "each process have only one device (GPU). If you want to use "
+              "all visible devices, set this to empty string. NOTE: the "
+              "reason of doing this is that we want to use P2P communication"
+              "between GPU devices, use CUDA_VISIBLE_DEVICES can only use"
+              "share-memory only.");
+
 namespace paddle {
 namespace platform {
 
@@ -121,6 +131,24 @@ int GetCurrentDeviceId() {
   return device_id;
 }
 
+//! Get a list of device ids from environment variable or use all.
+std::vector<int> GetSelectedDevices() {
+  // use user specified GPUs in single-node multi-process mode.
+  std::vector<int> devices;
+  if (!FLAGS_selected_gpus.empty()) {
+    auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
+    for (auto id : devices_str) {
+      devices.push_back(atoi(id.c_str()));
+    }
+  } else {
+    int count = GetCUDADeviceCount();
+    for (int i = 0; i < count; ++i) {
+      devices.push_back(i);
+    }
+  }
+  return devices;
+}
+
 void SetDeviceId(int id) {
   // TODO(qijun): find a better way to cache the cuda device count
   PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
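A hedged usage sketch of the new flag: a single-node multi-process launcher pins each worker to one device by exporting FLAGS_selected_gpus before the worker parses its flags (train.py and PADDLE_TRAINER_ID are illustrative; only the flag name comes from the definition above):

    import os
    import subprocess

    # one worker per GPU; each worker sees exactly one selected device
    for trainer_id, gpu_id in enumerate([0, 1, 2, 3]):
        env = dict(os.environ)
        env["FLAGS_selected_gpus"] = str(gpu_id)    # picked up via --tryfromenv
        env["PADDLE_TRAINER_ID"] = str(trainer_id)  # hypothetical worker env
        subprocess.Popen(["python", "train.py"], env=env)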
diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h
index 6a0b3c8e02..1e1ab2503f 100644
--- a/paddle/fluid/platform/gpu_info.h
+++ b/paddle/fluid/platform/gpu_info.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <cuda_runtime.h>
 #include <stddef.h>
 #include <string>
+#include <vector>
 
 namespace paddle {
 namespace platform {
@@ -47,6 +48,9 @@ int GetCUDAMaxThreadsPerMultiProcessor(int i);
 //! Get the current GPU device id in system.
 int GetCurrentDeviceId();
 
+//! Get a list of device ids from environment variable or use all.
+std::vector<int> GetSelectedDevices();
+
 //! Set the GPU device id for next execution.
 void SetDeviceId(int device_id);
 
diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc
index 258779ba51..0d10d82d74 100644
--- a/paddle/fluid/platform/init.cc
+++ b/paddle/fluid/platform/init.cc
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/cpu_helper.h"
 #include "paddle/fluid/platform/cpu_info.h"
+#include "paddle/fluid/string/split.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #endif
@@ -82,10 +83,8 @@ void InitDevices(bool init_p2p) {
   std::vector<int> devices;
 #ifdef PADDLE_WITH_CUDA
   try {
-    int count = platform::GetCUDADeviceCount();
-    for (int i = 0; i < count; ++i) {
-      devices.push_back(i);
-    }
+    // use user specified GPUs in single-node multi-process mode.
+    devices = platform::GetSelectedDevices();
   } catch (const std::exception &exp) {
     LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
   }
@@ -95,20 +94,15 @@ void InitDevices(bool init_p2p) {
 
 void InitDevices(bool init_p2p, const std::vector<int> devices) {
   std::vector<platform::Place> places;
-  int count = 0;
-#ifdef PADDLE_WITH_CUDA
-  try {
-    count = platform::GetCUDADeviceCount();
-  } catch (const std::exception &exp) {
-    LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
-  }
-#endif
 
   for (size_t i = 0; i < devices.size(); ++i) {
-    if (devices[i] >= count || devices[i] < 0) {
+    // In multi-process multi-GPU mode we may have gpu id = 7
+    // while the local device count is 1, so only reject negative ids.
+    if (devices[i] < 0) {
       LOG(WARNING) << "Invalid devices id.";
       continue;
     }
+
     places.emplace_back(platform::CUDAPlace(devices[i]));
   }
   if (init_p2p) {
@@ -122,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #endif
 
 #if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__)
-  if (platform::jit::MayIUse(platform::jit::avx)) {
+  if (platform::MayIUse(platform::avx)) {
 #ifndef __AVX__
     LOG(WARNING) << "AVX is available, Please re-compile on local machine";
 #endif
@@ -137,10 +131,10 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
          " version or compile from source code."
 
 #ifdef __AVX512F__
-  if (!platform::jit::MayIUse(platform::jit::avx512f)) {
-    if (platform::jit::MayIUse(platform::jit::avx2)) {
+  if (!platform::MayIUse(platform::avx512f)) {
+    if (platform::MayIUse(platform::avx2)) {
       AVX_GUIDE(AVX512, AVX2);
-    } else if (platform::jit::MayIUse(platform::jit::avx)) {
+    } else if (platform::MayIUse(platform::avx)) {
       AVX_GUIDE(AVX512, AVX);
     } else {
       AVX_GUIDE(AVX512, NonAVX);
@@ -149,8 +143,8 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #endif
 
 #ifdef __AVX2__
-  if (!platform::jit::MayIUse(platform::jit::avx2)) {
-    if (platform::jit::MayIUse(platform::jit::avx)) {
+  if (!platform::MayIUse(platform::avx2)) {
+    if (platform::MayIUse(platform::avx)) {
       AVX_GUIDE(AVX2, AVX);
     } else {
       AVX_GUIDE(AVX2, NonAVX);
@@ -159,7 +153,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #endif
 
 #ifdef __AVX__
-  if (!platform::jit::MayIUse(platform::jit::avx)) {
+  if (!platform::MayIUse(platform::avx)) {
     AVX_GUIDE(AVX, NonAVX);
   }
 #endif
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 167bd4e81d..e53064893e 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -113,6 +113,18 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
     return mkldnn::memory::format::x;
   } else if (dims_size == 2) {
     return mkldnn::memory::format::nc;
+  } else if (dims_size == 3) {
+    if (data_format == mkldnn::memory::format::nchw) {
+      return mkldnn::memory::format::ncw;
+    } else if (data_format == mkldnn::memory::format::nhwc) {
+      return mkldnn::memory::format::nwc;
+    }
+  } else if (dims_size == 5) {
+    if (data_format == mkldnn::memory::format::nchw) {
+      return mkldnn::memory::format::ncdhw;
+    } else if (data_format == mkldnn::memory::format::nhwc) {
+      return mkldnn::memory::format::ndhwc;
+    }
   }
   return data_format;
 }
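The dims-to-format mapping above, written out as data (a readability sketch; 1-D and 2-D inputs map unconditionally to x and nc, and the 3-D and 5-D rows are the ones added here):

    # (dims, incoming layout) -> MKLDNN memory format
    MKLDNN_FORMAT_FOR_SIZE = {
        (3, "nchw"): "ncw",
        (3, "nhwc"): "nwc",
        (5, "nchw"): "ncdhw",
        (5, "nhwc"): "ndhwc",
    }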
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index fc903b548c..7c539d25f6 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -97,7 +97,7 @@ struct NCCLContextMap {
         order_.size(), contexts_.size(),
         "NCCL Context Map does not support contain two or more same device");
 
-    if (places.size() <= 1) {
+    if (places.size() <= 1 && num_trainers == 1) {
       return;
     }
     std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
@@ -111,12 +111,19 @@ struct NCCLContextMap {
       {
         int nranks = num_trainers * order_.size();
         NCCLGroupGuard gurad;
-        for (auto &gpu_id : order_) {
-          int rank = trainer_id * order_.size() + gpu_id;
-          VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
+        for (size_t i = 0; i < order_.size(); ++i) {
+          int gpu_id = order_[i];
+          int rank;
+          if (order_.size() > 1) {
+            rank = trainer_id * order_.size() + i;
+          } else {
+            rank = trainer_id;
+          }
+          VLOG(30) << "init nccl rank: " << rank << " nranks: " << nranks
+                   << "gpu id: " << gpu_id;
           PADDLE_ENFORCE(cudaSetDevice(gpu_id));
           PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
-              comms.get() + gpu_id, nranks, *nccl_id, rank));
+              comms.get() + i, nranks, *nccl_id, rank));
         }
       }
     }
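The rank assignment above, extracted into a standalone sketch (pure Python; order_ corresponds to the list of local GPU ids):

    def nccl_ranks(trainer_id, num_local_gpus):
        # multi-GPU per trainer: global rank = trainer_id * ngpus + local index
        if num_local_gpus > 1:
            return [trainer_id * num_local_gpus + i for i in range(num_local_gpus)]
        # one GPU per trainer (multi-process mode): rank is just trainer_id
        return [trainer_id]

    assert nccl_ranks(trainer_id=1, num_local_gpus=4) == [4, 5, 6, 7]
    assert nccl_ranks(trainer_id=3, num_local_gpus=1) == [3]

Note that the communicator is now stored at comms.get() + i (the loop index) rather than comms.get() + gpu_id, so the comm array stays dense even when the selected GPU ids are not 0..n-1.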
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index fc7991d297..58ef3da0b2 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -886,6 +886,18 @@ All parameter, weight, gradient are variables in Paddle.
           [](BuildStrategy &self, int num_trainers) {
             self.num_trainers_ = num_trainers;
           })
+      .def_property(
+          "trainers_endpoints",
+          [](const BuildStrategy &self) { return self.trainers_endpoints_; },
+          [](BuildStrategy &self,
+             const std::vector<std::string> &trainers_endpoints) {
+            self.trainers_endpoints_ = trainers_endpoints;
+          })
+      .def_property("trainer_id",
+                    [](const BuildStrategy &self) { return self.trainer_id_; },
+                    [](BuildStrategy &self, int trainer_id) {
+                      self.trainer_id_ = trainer_id;
+                    })
       .def_property(
           "fuse_elewise_add_act_ops",
           [](const BuildStrategy &self) {
diff --git a/paddle/fluid/string/CMakeLists.txt b/paddle/fluid/string/CMakeLists.txt
index 8572dc1e8e..169a925d12 100644
--- a/paddle/fluid/string/CMakeLists.txt
+++ b/paddle/fluid/string/CMakeLists.txt
@@ -3,3 +3,4 @@ cc_library(pretty_log SRCS pretty_log.cc)
 cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
 cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
 cc_test(to_string_test SRCS to_string_test.cc)
+cc_test(split_test SRCS split_test.cc)
diff --git a/paddle/fluid/string/split.h b/paddle/fluid/string/split.h
new file mode 100644
index 0000000000..ccb96b8a9c
--- /dev/null
+++ b/paddle/fluid/string/split.h
@@ -0,0 +1,37 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace string {
+
+static inline std::vector<std::string> Split(std::string const& original,
+                                             char separator) {
+  std::vector<std::string> results;
+  std::string token;
+  std::istringstream is(original);
+  while (std::getline(is, token, separator)) {
+    if (!token.empty()) {
+      results.push_back(token);
+    }
+  }
+  return results;
+}
+
+}  // namespace string
+}  // namespace paddle
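A Python mirror of Split, to make one subtlety explicit: empty tokens are dropped (that is what the !token.empty() check above does), which differs from plain str.split:

    def split(original, separator):
        # drop empty tokens, matching the !token.empty() check above
        return [tok for tok in original.split(separator) if tok]

    assert split("0,1,2,3", ",") == ["0", "1", "2", "3"]
    assert split("0,,1,", ",") == ["0", "1"]      # empty tokens are dropped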
diff --git a/paddle/fluid/string/split_test.cc b/paddle/fluid/string/split_test.cc
new file mode 100644
index 0000000000..c85dc1eed4
--- /dev/null
+++ b/paddle/fluid/string/split_test.cc
@@ -0,0 +1,28 @@
+//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/string/split.h"
+
+#include <string>
+
+#include "gtest/gtest.h"
+
+TEST(StringSplit, StringSplit) {
+  std::string to_split = "0,1,2,3,4,5";
+  int i = 0;
+  for (auto s : paddle::string::Split(to_split, ',')) {
+    EXPECT_EQ(atoi(s.c_str()), i);
+    i++;
+  }
+}
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index dbb73f7a27..6299b166af 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -437,7 +437,7 @@ EOF
         export http_proxy=
         export https_proxy=
         # TODO: jiabin need to refine this part when these tests fixed on mac
-        ctest --output-on-failure -j $1
+        ctest --output-on-failure -j $2
         # make install should also be test when unittest
         make install -j 8
         if [ "$1" == "cp27-cp27m" ]; then
@@ -449,7 +449,7 @@ EOF
         elif [ "$1" == "cp37-cp37m" ]; then
             pip3.7 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl
         fi
-
+
         if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]] ; then
             paddle version
         fi
@@ -472,12 +472,15 @@ function assert_api_not_changed() {
     virtualenv .env
     source .env/bin/activate
     pip install ${PADDLE_ROOT}/build/python/dist/*whl
-    python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid > new.spec
+    python ${PADDLE_ROOT}/tools/print_signatures.py paddle.fluid,paddle.reader > new.spec
     if [ "$1" == "cp35-cp35m" ] || [ "$1" == "cp36-cp36m" ] || [ "$1" == "cp37-cp37m" ]; then
         # Use sed to make python2 and python3 sepc keeps the same
         sed -i 's/arg0: str/arg0: unicode/g' new.spec
         sed -i "s/\(.*Transpiler.*\).__init__ ArgSpec(args=\['self'].*/\1.__init__ /g" new.spec
     fi
+    # ComposeNotAligned has significant difference between py2 and py3
+    sed -i '/.*ComposeNotAligned.*/d' new.spec
+
     python ${PADDLE_ROOT}/tools/diff_api.py ${PADDLE_ROOT}/paddle/fluid/API.spec new.spec
     deactivate
 }
@@ -487,7 +490,19 @@ function assert_api_spec_approvals() {
         BRANCH="develop"
     fi
 
-    API_FILES=("paddle/fluid/API.spec" "paddle/fluid/framework/operator.h")
+    API_FILES=("paddle/fluid/API.spec"
+               "paddle/fluid/framework/operator.h"
+               "paddle/fluid/framework/tensor.h"
+               "paddle/fluid/framework/lod_tensor.h"
+               "paddle/fluid/framework/selected_rows.h"
+               "paddle/fluid/framework/op_desc.h"
+               "paddle/fluid/framework/block_desc.h"
+               "paddle/fluid/framework/var_desc.h"
+               "paddle/fluid/framework/scope.h"
+               "paddle/fluid/framework/ir/node.h"
+               "paddle/fluid/framework/ir/graph.h"
+               "paddle/fluid/framework/framework.proto"
+               "paddle/fluid/operators/distributed/send_recv.proto.in")
     for API_FILE in ${API_FILES[*]}; do
       API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "${API_FILE}" || true`
       echo "checking ${API_FILE} change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}"
@@ -901,7 +916,7 @@ function main() {
       maccheck)
         cmake_gen ${PYTHON_ABI:-""}
         build_mac
-        run_mac_test ${PROC_RUN:-1}
+        run_mac_test ${PYTHON_ABI:-""} ${PROC_RUN:-1}
         ;;
       macbuild)
         cmake_gen ${PYTHON_ABI:-""}
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index a1ffbf4262..2a53519188 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -147,7 +147,7 @@ def __bootstrap__():
         read_env_flags += [
             'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
             'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
-            'cudnn_exhaustive_search'
+            'cudnn_exhaustive_search', 'selected_gpus'
         ]
     core.init_gflags([sys.argv[0]] +
                      ["--tryfromenv=" + ",".join(read_env_flags)])
diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index 1738afe93e..0f7dd531b3 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -134,12 +134,12 @@ class GradientClipByValue(BaseGradientClipAttr):
     Examples:
         .. code-block:: python
 
-            w_param_attrs = ParamAttr(name=None,
-              initializer=UniformInitializer(low=-1.0, high=1.0, seed=0),
+            w_param_attrs = fluid.ParamAttr(name=None,
+              initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0),
               learning_rate=1.0,
-              regularizer=L1Decay(1.0),
+              regularizer=fluid.regularizer.L1Decay(1.0),
               trainable=True,
-              clip=GradientClipByValue(-1.0, 1.0))
+              clip=fluid.clip.GradientClipByValue(-1.0, 1.0))
             y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
     """
 
@@ -185,12 +185,12 @@ class GradientClipByNorm(BaseGradientClipAttr):
     Examples:
         .. code-block:: python
 
-            w_param_attrs = ParamAttr(name=None,
-              initializer=UniformInitializer(low=-1.0, high=1.0, seed=0),
+            w_param_attrs = fluid.ParamAttr(name=None,
+              initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0),
               learning_rate=1.0,
-              regularizer=L1Decay(1.0),
+              regularizer=fluid.regularizer.L1Decay(1.0),
               trainable=True,
-              clip=GradientClipByNorm(clip_norm=2.0))
+              clip=fluid.clip.GradientClipByNorm(clip_norm=2.0))
             y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
 
     """
@@ -271,7 +271,12 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
                     "All parameters' 'clip_norm' of a same group should be the same"
                 )
 
-        square = grad * grad
+        merge_grad = grad
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            merge_grad = layers.merge_selected_rows(grad)
+            merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
+
+        square = layers.square(merge_grad)
         local_norm_var = layers.reduce_sum(input=square)
         context[self.group_name].append(local_norm_var)
 
@@ -292,6 +297,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
 
         new_grad = layers.elementwise_mul(
             x=grad, y=self.context[group_scale_name])
+
         return param, new_grad
 
 
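Why the gradient is merged before squaring: a SELECTED_ROWS gradient may contain duplicate rows (e.g. from repeated embedding lookups), and the global norm must be computed on the row-summed values. A plain-numpy illustration of the difference (not Paddle API; the numbers are arbitrary):

    import numpy as np

    rows = [0, 2, 0]                                  # row 0 appears twice
    values = np.array([[1., 1.], [2., 2.], [3., 3.]])

    unmerged_sumsq = float(np.sum(values ** 2))       # 28.0, over-counts row 0

    merged = {}                                       # what merge_selected_rows does
    for r, v in zip(rows, values):
        merged[r] = merged.get(r, 0.0) + v
    merged_sumsq = float(sum(np.sum(v ** 2) for v in merged.values()))  # 40.0

    assert unmerged_sumsq != merged_sumsq             # hence the merge above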
diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py
index 5102a558fd..13d2893fd1 100644
--- a/python/paddle/fluid/data_feeder.py
+++ b/python/paddle/fluid/data_feeder.py
@@ -258,10 +258,13 @@ class DataFeeder(object):
         multiple mini-batches. Each mini-batch will be feed on each device.
 
         Args:
-            reader(fun): the input data.
-            multi_devices(bool): the number of places. Default None.
-            num_places(int): the number of places. Default None.
-            drop_last(bool): the number of places. Default None.
+            reader(function): the reader is a function that generates the data to feed.
+            multi_devices(bool): whether to use multiple devices or not.
+            num_places(int): if multi_devices is True, you can specify the number
+                of GPUs to use; if num_places is None, the function will use all
+                the GPUs of the current machine. Default None.
+            drop_last(bool): whether to drop the last batch if its size is less
+                than batch_size. Default True.
 
         Returns:
             dict: the result of conversion.
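A minimal usage sketch for decorate_reader with the documented arguments (the layer name and reader below are illustrative; a real run would pass the decorated reader's output to an executor):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    feeder = fluid.DataFeeder(feed_list=[x], place=fluid.CPUPlace())

    def batch_reader():
        # a reader yields mini-batches; each sample is a one-element tuple
        for _ in range(10):
            yield [([0.0, 1.0, 2.0, 3.0],)] * 8

    multi_reader = feeder.decorate_reader(
        batch_reader, multi_devices=True, num_places=2, drop_last=True)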
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 42c2484b28..f2886090d7 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -20,7 +20,7 @@ import six
 from .framework import Program, default_main_program, Variable
 from . import core
 
-__all__ = ['Executor', 'global_scope', 'scope_guard', '_switch_scope']
+__all__ = ['Executor', 'global_scope', 'scope_guard']
 
 g_scope = core.Scope()
 
@@ -407,16 +407,17 @@ class Executor(object):
 
         Examples:
 
-            >>> data = layers.data(name='X', shape=[1], dtype='float32')
-            >>> hidden = layers.fc(input=data, size=10)
-            >>> layers.assign(hidden, out)
-            >>> loss = layers.mean(out)
+            >>> data = fluid.layers.data(name='X', shape=[1], dtype='float32')
+            >>> out = fluid.layers.create_tensor(dtype='float32')
+            >>> hidden = fluid.layers.fc(input=data, size=10)
+            >>> fluid.layers.assign(hidden,out)
+            >>> loss = fluid.layers.mean(out)
             >>> adam = fluid.optimizer.Adam()
-            >>> adam.minimize(loss)
+            >>> adam.minimize(loss)
 
             >>> cpu = core.CPUPlace()
-            >>> exe = Executor(cpu)
-            >>> exe.run(default_startup_program())
+            >>> exe = fluid.Executor(cpu)
+            >>> exe.run(fluid.default_startup_program())
 
             >>> x = numpy.random.random(size=(10, 1)).astype('float32')
             >>> outs = exe.run(
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index b991187d42..a40826168d 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -89,12 +89,13 @@ def name_scope(prefix=None):
 
     Examples:
         .. code-block:: python
+
           with name_scope("encoder"):
              ...
           with name_scope("decoder"):
              ...
-             with name_scope("attention"):
-                ...
+          with name_scope("attention"):
+             ...
     """
     # TODO(panyx0718): Only [0-9a-z].
     assert prefix, "namescope prefix cannot be empty."
@@ -1441,6 +1442,7 @@ class Program(object):
         self._is_chief = False
         self._slice_vars_and_attrs = []
         self._endpoints = []
+        self._trainers_endpoints = []
         self._distributed_lookup_table = None
 
     @property
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 0782933c6c..e74a87fc68 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -145,7 +145,7 @@ def save_vars(executor,
 
             prog = fluid.default_main_program()
             fluid.io.save_vars(executor=exe, dirname=path, main_program=prog,
-                               vars=None)
+                               vars=None, predicate=name_has_fc)
             # All variables in `main_program` whose name includes "fc" will be saved.
             # And variables are going to be saved separately.
 
@@ -369,7 +369,7 @@ def load_vars(executor,
 
             prog = fluid.default_main_program()
             fluid.io.load_vars(executor=exe, dirname=path, main_program=prog,
-                               vars=None)
+                               vars=None, predicate=name_has_fc)
             # All variables in `main_program` whose name includes "fc" will be loaded.
             # And all the variables are supposed to have been saved in differnet files.
 
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 4843af8340..ce731f39ea 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 from .layer_function_generator import generate_layer_fn
 from .layer_function_generator import autodoc, templatedoc
 from ..layer_helper import LayerHelper
+from ..framework import Variable
 from . import tensor
 from . import nn
 from . import ops
@@ -46,6 +47,7 @@ __all__ = [
     'iou_similarity',
     'box_coder',
     'polygon_box_transform',
+    'yolov3_loss',
 ]
 
 
@@ -401,6 +403,113 @@ def polygon_box_transform(input, name=None):
     return output
 
 
+@templatedoc(op_type="yolov3_loss")
+def yolov3_loss(x,
+                gtbox,
+                gtlabel,
+                anchors,
+                class_num,
+                ignore_thresh,
+                loss_weight_xy=None,
+                loss_weight_wh=None,
+                loss_weight_conf_target=None,
+                loss_weight_conf_notarget=None,
+                loss_weight_class=None,
+                name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Variable): ${x_comment}
+        gtbox (Variable): ground truth boxes, should be in shape of [N, B, 4],
+                          where the third dimension stores x, y, w, h as
+                          values relative to the input image size. N is the
+                          batch size and B is the max box number in an image.
+        gtlabel (Variable): class ids of ground truth boxes, should be in
+                            shape of [N, B].
+        anchors (list|tuple): ${anchors_comment}
+        class_num (int): ${class_num_comment}
+        ignore_thresh (float): ${ignore_thresh_comment}
+        loss_weight_xy (float|None): ${loss_weight_xy_comment}
+        loss_weight_wh (float|None): ${loss_weight_wh_comment}
+        loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment}
+        loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment}
+        loss_weight_class (float|None): ${loss_weight_class_comment}
+        name (string): the name of yolov3 loss
+
+    Returns:
+        Variable: A 1-D tensor with shape [1], the value of yolov3 loss
+
+    Raises:
+        TypeError: Input x of yolov3_loss must be Variable
+        TypeError: Input gtbox of yolov3_loss must be Variable
+        TypeError: Input gtlabel of yolov3_loss must be Variable
+        TypeError: Attr anchors of yolov3_loss must be list or tuple
+        TypeError: Attr class_num of yolov3_loss must be an integer
+        TypeError: Attr ignore_thresh of yolov3_loss must be a float number
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
+            gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32')
+            gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
+            anchors = [10, 13, 16, 30, 33, 23]
+            loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel,
+                                            anchors=anchors, class_num=80,
+                                            ignore_thresh=0.5)
+    """
+    helper = LayerHelper('yolov3_loss', **locals())
+
+    if not isinstance(x, Variable):
+        raise TypeError("Input x of yolov3_loss must be Variable")
+    if not isinstance(gtbox, Variable):
+        raise TypeError("Input gtbox of yolov3_loss must be Variable")
+    if not isinstance(gtlabel, Variable):
+        raise TypeError("Input gtlabel of yolov3_loss must be Variable")
+    if not isinstance(anchors, list) and not isinstance(anchors, tuple):
+        raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
+    if not isinstance(class_num, int):
+        raise TypeError("Attr class_num of yolov3_loss must be an integer")
+    if not isinstance(ignore_thresh, float):
+        raise TypeError(
+            "Attr ignore_thresh of yolov3_loss must be a float number")
+
+    if name is None:
+        loss = helper.create_variable_for_type_inference(dtype=x.dtype)
+    else:
+        loss = helper.create_variable(
+            name=name, dtype=x.dtype, persistable=False)
+
+    attrs = {
+        "anchors": anchors,
+        "class_num": class_num,
+        "ignore_thresh": ignore_thresh,
+    }
+
+    if loss_weight_xy is not None and isinstance(loss_weight_xy, float):
+        attrs['loss_weight_xy'] = loss_weight_xy
+    if loss_weight_wh is not None and isinstance(loss_weight_wh, float):
+        attrs['loss_weight_wh'] = loss_weight_wh
+    if loss_weight_conf_target is not None and isinstance(
+            loss_weight_conf_target, float):
+        attrs['loss_weight_conf_target'] = loss_weight_conf_target
+    if loss_weight_conf_notarget is not None and isinstance(
+            loss_weight_conf_notarget, float):
+        attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget
+    if loss_weight_class is not None and isinstance(loss_weight_class, float):
+        attrs['loss_weight_class'] = loss_weight_class
+
+    helper.append_op(
+        type='yolov3_loss',
+        inputs={"X": x,
+                "GTBox": gtbox,
+                "GTLabel": gtlabel},
+        outputs={'Loss': loss},
+        attrs=attrs)
+    return loss
+
+
 @templatedoc()
 def detection_map(detect_res,
                   label,
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 3f47053961..42f4959a83 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -943,7 +943,18 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
 
 def shuffle(reader, buffer_size):
     """
-    Shuffle the reader.
+    Creates a data reader whose output is shuffled.
+    Output from the original reader will be buffered into a shuffle buffer
+    and then shuffled. The size of the shuffle buffer is determined by the
+    argument buffer_size.
+
+    Args:
+        reader(Variable): the reader variable whose output will be shuffled.
+        buffer_size(int): the size of the shuffle buffer.
+
+    Returns:
+        Variable: the new reader whose output is shuffled.
     """
     return __create_unshared_decorated_reader__(
         'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
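A hedged usage sketch: shuffle decorates a reader variable created by another reader layer such as open_files (the file name and shapes below are illustrative):

    import paddle.fluid as fluid

    raw_reader = fluid.layers.open_files(
        filenames=['./mnist.recordio'],          # hypothetical recordio file
        shapes=[[-1, 784], [-1, 1]],
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'])
    shuffled_reader = fluid.layers.shuffle(raw_reader, buffer_size=8192)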
diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py
index eea0a362a0..09b1b30216 100644
--- a/python/paddle/fluid/layers/layer_function_generator.py
+++ b/python/paddle/fluid/layers/layer_function_generator.py
@@ -20,7 +20,7 @@ import string
 
 from six.moves import cStringIO
 from ..proto import framework_pb2
-from ..framework import OpProtoHolder, Variable
+from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
 from ..layer_helper import LayerHelper
 
 __all__ = [
@@ -178,6 +178,15 @@ def generate_layer_fn(op_type):
                         "operator {0} must input same dtype. {1} vs {2}".format(
                             op_type, dtype, each.dtype))
 
+        if dtype is None:
+            arg_dtype = kwargs.get("dtype")
+            if arg_dtype:
+                if not isinstance(arg_dtype, core.VarDesc.VarType):
+                    dtype = convert_np_dtype_to_dtype_(arg_dtype)
+                else:
+                    dtype = arg_dtype
+            else:
+                dtype = core.VarDesc.VarType.FP32
         return dtype
 
     def func(*args, **kwargs):
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index 149224bb68..dde0518972 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -308,13 +308,9 @@ def piecewise_decay(boundaries, values):
 
 
 def append_LARS(params_grads, learning_rate, weight_decay):
-    """Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for
-       each layer.
-
-    ```python
-        learning_rate *= local_gw_ratio * sqrt(sumsq(param))
-                        / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param)))
-    ```
+    """
+    Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for
+    each layer.
 
     Args:
         learning_rate: A learning rate Variable. This
@@ -323,6 +319,11 @@ def append_LARS(params_grads, learning_rate, weight_decay):
 
     Returns:
         The decayed learning rate
+    Examples:
+        .. code-block:: python
+
+            learning_rate *= local_gw_ratio * sqrt(sumsq(param)) \
+                             / (sqrt(sumsq(gradient)) + weight_decay * sqrt(sumsq(param)))
     """
 
     def _balanced_weight(param_norm, grad_norm):
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index ac1c865775..4833212d31 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -169,6 +169,8 @@ __all__ = [
     'log_loss',
     'add_position_encoding',
     'bilinear_tensor_product',
+    'merge_selected_rows',
+    'get_tensor_from_selected_rows',
     'lstm',
 ]
 
@@ -928,7 +930,7 @@ def dynamic_gru(input,
             emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
             hidden_dim = 512
             x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
-            hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim)
+            hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim)
     """
 
     helper = LayerHelper('gru', **locals())
@@ -3586,6 +3588,7 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None):
 
     Examples:
         .. code-block:: python
+
             # Suppose `ids` and `scores` are LodTensorArray variables reserving
             # the selected ids and scores of all steps
             finished_ids, finished_scores = layers.beam_search_decode(
@@ -5083,7 +5086,7 @@ def im2sequence(input,
 
             output.lod = [[4, 4]]
 
-     Examples:
+    Examples:
 
         .. code-block:: python
 
@@ -5870,24 +5873,23 @@ def pad_constant_like(x, y, pad_value=0., name=None):
                   [[38, 39, 40]],
                   [[41, 42, 43]]]]
             Y.shape = (1, 3, 1, 3)
+        And
+            pad_value = -1,
 
-    And
-        pad_value = -1,
-
-    Return:
-        Out = [[[[35, 36, 37],
-                  [-1, -1, -1]],
-                [[38, 39, 40],
-                  [-1, -1, -1]],
-                 [[41, 42, 43],
-                  [-1, -1, -1]]],
-                [[[-1, -1, -1],
-                  [-1, -1, -1]],
-                 [[-1, -1, -1],
-                  [-1, -1, -1]],
-                 [[-1, -1, -1],
-                  [-1, -1, -1]]]]
-        Out.shape = (2, 3, 2, 3)
+        Return:
+            Out = [[[[35, 36, 37],
+                     [-1, -1, -1]],
+                    [[38, 39, 40],
+                     [-1, -1, -1]],
+                    [[41, 42, 43],
+                     [-1, -1, -1]]],
+                  [[[-1, -1, -1],
+                    [-1, -1, -1]],
+                   [[-1, -1, -1],
+                    [-1, -1, -1]],
+                   [[-1, -1, -1],
+                    [-1, -1, -1]]]]
+            Out.shape = (2, 3, 2, 3)
 
     Args:
         x (Variable): The input tensor variable.
@@ -6126,6 +6128,7 @@ def image_resize(input,
     Supporting resample methods:
 
         'BILINEAR' : Bilinear interpolation
+
         'NEAREST' : Nearest neighbor interpolation
 
     Args:
@@ -6781,7 +6784,7 @@ def crop(x, shape=None, offsets=None, name=None):
 
             # or
             z = fluid.layers.data(name="z", shape=[3, 5], dtype="float32")
-            crop = fluid.layers.crop(z, shape=[2, 3])
+            crop = fluid.layers.crop(z, shape=[-1, 2, 3])
 
     """
     helper = LayerHelper('crop', **locals())
@@ -7062,39 +7065,40 @@ def pad2d(input,
     than height-1. And the width dimension has the same condition.
 
     Example:
+        .. code-block:: text
 
-      Given that X is a channel of image from input:
+            Given that X is a channel of image from input:
 
-      X = [[1, 2, 3],
-           [4, 5, 6]]
+            X = [[1, 2, 3],
+                 [4, 5, 6]]
 
-      Case 0:
+            Case 0:
 
-        paddings = [0, 1, 2, 3],
-        mode = 'constant'
-        pad_value = 0
+                paddings = [0, 1, 2, 3],
+                mode = 'constant'
+                pad_value = 0
 
-        Out = [[0, 0, 1, 2, 3, 0, 0, 0]
-               [0, 0, 4, 5, 6, 0, 0, 0]
-               [0, 0, 0, 0, 0, 0, 0, 0]]
+                Out = [[0, 0, 1, 2, 3, 0, 0, 0]
+                       [0, 0, 4, 5, 6, 0, 0, 0]
+                       [0, 0, 0, 0, 0, 0, 0, 0]]
 
-      Case 1:
+            Case 1:
 
-        paddings = [0, 1, 2, 1],
-        mode = 'reflect'
+                paddings = [0, 1, 2, 1],
+                mode = 'reflect'
 
-        Out = [[3, 2, 1, 2, 3, 2]
-               [6, 5, 4, 5, 6, 5]
-               [3, 2, 1, 2, 3, 2]]
+                Out = [[3, 2, 1, 2, 3, 2]
+                       [6, 5, 4, 5, 6, 5]
+                       [3, 2, 1, 2, 3, 2]]
 
-      Case 2:
+            Case 2:
 
-        paddings = [0, 1, 2, 1],
-        mode = 'edge'
+                paddings = [0, 1, 2, 1],
+                mode = 'edge'
 
-        Out = [[1, 1, 1, 2, 3, 3]
-               [4, 4, 4, 5, 6, 6]
-               [4, 4, 4, 5, 6, 6]]
+                Out = [[1, 1, 1, 2, 3, 3]
+                       [4, 4, 4, 5, 6, 6]
+                       [4, 4, 4, 5, 6, 6]]
 
 
     Args:
@@ -7332,13 +7336,13 @@ def prelu(x, mode, param_attr=None, name=None):
     Args:
         x (Variable): The input tensor.
         param_attr(ParamAttr|None): The parameter attribute for the learnable
-                       weight (alpha).
+          weight (alpha).
         mode (string): The mode for weight sharing. It supports all, channel
-                       and element. all: all elements share same weight
-                       channel:elements in a channel share same weight
-                       element:each element has a weight
+          and element. all: all elements share the same weight;
+          channel: elements in a channel share the same weight;
+          element: each element has its own weight.
         name(str|None): A name for this layer(optional). If set None, the layer
-                       will be named automatically.
+          will be named automatically.
 
     Returns:
         Variable: The output tensor with the same shape as input.
@@ -8380,6 +8384,29 @@ def mean(x, name=None):
     return out
 
 
+@templatedoc()
+def merge_selected_rows(x, name=None):
+    """
+    ${comment}
+
+    Args:
+        x(${x_type}): ${x_comment}
+        name(basestring|None): Name of the output.
+
+    Returns:
+        out(${out_type}): ${out_comment}
+    """
+
+    helper = LayerHelper("merge_selected_rows", **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type="merge_selected_rows",
+        inputs={"X": x},
+        attrs={},
+        outputs={"Out": out})
+    return out
+
+
 @templatedoc()
 def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
     """
@@ -9032,3 +9059,26 @@ def bilinear_tensor_product(x,
 
     # add activation
     return helper.append_activation(out)
+
+
+@templatedoc()
+def get_tensor_from_selected_rows(x, name=None):
+    """
+    ${comment}
+
+    Args:
+        x(${x_type}): ${x_comment}
+        name(basestring|None): Name of the output.
+
+    Returns:
+        out(${out_type}): ${out_comment}
+    """
+
+    helper = LayerHelper('get_tensor_from_selected_rows', **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='get_tensor_from_selected_rows',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={})
+    return out
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index ff32c00104..49a486cf0c 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -622,7 +622,7 @@ def reverse(x, axis):
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
         type='reverse',
-        inputs={'Input': x},
+        inputs={'X': x},
         outputs={'Out': [out]},
         attrs={'axis': axis})
     return out
diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py
index 829154f1b2..85af8fea13 100644
--- a/python/paddle/fluid/metrics.py
+++ b/python/paddle/fluid/metrics.py
@@ -222,13 +222,13 @@ class Precision(MetricBase):
     Examples:
         .. code-block:: python
 
-        metric = fluid.metrics.Precision()
-        for pass in range(PASSES):
-            metric.reset()
-            for data in train_reader():
-                loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
-            metric.update(preds=preds, labels=labels)
-            numpy_precision = metric.eval()
+            metric = fluid.metrics.Precision()
+            for pass_id in range(PASSES):
+                metric.reset()
+                for data in train_reader():
+                    loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
+                metric.update(preds=preds, labels=labels)
+                numpy_precision = metric.eval()
     """
 
     def __init__(self, name=None):
@@ -267,13 +267,13 @@ class Recall(MetricBase):
     Examples:
         .. code-block:: python
 
-        metric = fluid.metrics.Recall()
-        for pass in range(PASSES):
-            metric.reset()
-            for data in train_reader():
-                loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
-            metric.update(preds=preds, labels=labels)
-            numpy_recall = metric.eval()
+            metric = fluid.metrics.Recall()
+            for pass_id in range(PASSES):
+                metric.reset()
+                for data in train_reader():
+                    loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
+                metric.update(preds=preds, labels=labels)
+                numpy_recall = metric.eval()
     """
 
     def __init__(self, name=None):
@@ -449,8 +449,9 @@ class EditDistance(MetricBase):
                 distance_evaluator.update(distances, seq_num)
                 distance, instance_error = distance_evaluator.eval()
 
-        In the above example:
+    In the above example:
         'distance' is the average of the edit distance in a pass.
+
         'instance_error' is the instance error rate in a pass.
 
     """
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index bdcd045341..c54c3963a1 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -95,7 +95,14 @@ class ParallelExecutor(object):
         self._places = []
         self._act_places = []
         if use_cuda:
-            for i in six.moves.range(core.get_cuda_device_count()):
+            gpus = []
+            gpus_env = os.getenv("FLAGS_selected_gpus")
+            if gpus_env:
+                gpus = [int(s) for s in gpus_env.split(",")]
+            else:
+                for i in six.moves.range(core.get_cuda_device_count()):
+                    gpus.append(i)
+            for i in gpus:
                 p = core.Place()
                 self._act_places.append(core.CUDAPlace(i))
                 p.set_place(self._act_places[-1])
@@ -128,9 +135,17 @@ class ParallelExecutor(object):
             build_strategy = BuildStrategy()
 
         build_strategy.num_trainers = num_trainers
+        build_strategy.trainer_id = trainer_id
 
         main = main_program
         main = main if main else framework.default_main_program()
+
+        trainers_endpoints = main._trainers_endpoints
+        if num_trainers > 1 and trainers_endpoints:
+            assert num_trainers == len(
+                trainers_endpoints), "num_trainers == len(end_points)"
+            build_strategy.trainers_endpoints = trainers_endpoints
+
         if scope == None:
             scope = executor.global_scope()
 
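The device-selection branch added above, as a standalone sketch (pure Python; it mirrors the FLAGS_selected_gpus handling in ParallelExecutor):

    import os

    def selected_gpu_ids(device_count):
        env = os.getenv("FLAGS_selected_gpus")
        if env:
            return [int(s) for s in env.split(",")]
        return list(range(device_count))

    os.environ.pop("FLAGS_selected_gpus", None)
    assert selected_gpu_ids(4) == [0, 1, 2, 3]    # fall back to all devices
    os.environ["FLAGS_selected_gpus"] = "2,5"
    assert selected_gpu_ids(4) == [2, 5]          # env list wins over count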
diff --git a/python/paddle/fluid/param_attr.py b/python/paddle/fluid/param_attr.py
index a51607bfdb..38ddf93198 100644
--- a/python/paddle/fluid/param_attr.py
+++ b/python/paddle/fluid/param_attr.py
@@ -50,8 +50,9 @@ class ParamAttr(object):
 
             w_param_attrs = fluid.ParamAttr(name="fc_weight",
                                             learning_rate=0.5,
-                                            regularizer=fluid.L2Decay(1.0),
+                                            regularizer=fluid.regularizer.L2Decay(1.0),
                                             trainable=True)
+            x = fluid.layers.data(name='X', shape=[1], dtype='float32')
             y_predict = fluid.layers.fc(input=x, size=10, param_attr=w_param_attrs)
     """
 
diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py
index a2eca5541a..d99eaa0634 100644
--- a/python/paddle/fluid/tests/test_detection.py
+++ b/python/paddle/fluid/tests/test_detection.py
@@ -388,5 +388,18 @@ class TestGenerateProposals(unittest.TestCase):
         print(rpn_rois.shape)
 
 
+class TestYoloDetection(unittest.TestCase):
+    def test_yolov3_loss(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
+            gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
+            gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
+            loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10,
+                                      0.5)
+
+            self.assertIsNotNone(loss)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/test_gradient_clip.py b/python/paddle/fluid/tests/test_gradient_clip.py
deleted file mode 100644
index 266687fcd0..0000000000
--- a/python/paddle/fluid/tests/test_gradient_clip.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import numpy as np
-import paddle
-import paddle.fluid as fluid
-
-BATCH_SIZE = 128
-CLIP = 1
-
-prog = fluid.framework.Program()
-with fluid.program_guard(main_program=prog):
-    image = fluid.layers.data(name='x', shape=[784], dtype='float32')
-
-    hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
-    hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
-    predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
-
-    label = fluid.layers.data(name='y', shape=[1], dtype='int64')
-
-    cost = fluid.layers.cross_entropy(input=predict, label=label)
-    avg_cost = fluid.layers.mean(cost)
-
-prog_clip = prog.clone()
-
-avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
-
-p_g = fluid.backward.append_backward(loss=avg_cost)
-p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
-
-with fluid.program_guard(main_program=prog_clip):
-    fluid.clip.set_gradient_clip(
-        fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
-    p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
-
-grad_list = [elem[1] for elem in p_g]
-grad_clip_list = [elem[1] for elem in p_g_clip]
-
-train_reader = paddle.batch(
-    paddle.reader.shuffle(
-        paddle.dataset.mnist.train(), buf_size=8192),
-    batch_size=BATCH_SIZE)
-
-place = fluid.CPUPlace()
-exe = fluid.Executor(place)
-feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
-exe.run(fluid.default_startup_program())
-
-count = 0
-for data in train_reader():
-    count += 1
-    if count > 5:
-        break
-    out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
-    out_clip = exe.run(prog_clip,
-                       feed=feeder.feed(data),
-                       fetch_list=grad_clip_list)
-    global_norm = 0
-    for v in out[1:]:
-        global_norm += np.sum(np.power(v, 2))
-    global_norm = np.sqrt(global_norm)
-
-    global_norm_clip = 0
-    for v in out_clip[1:]:
-        global_norm_clip += np.sum(np.power(v, 2))
-    global_norm_clip = np.sqrt(global_norm_clip)
-
-    if not np.isclose(
-            a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3):
-        exit(1)
-exit(0)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 26035f303e..a4089ba3ca 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -43,7 +43,7 @@ if(APPLE)
         list(REMOVE_ITEM TEST_OPS test_desc_clone)
         list(REMOVE_ITEM TEST_OPS test_program_code)
     endif(NOT WITH_DISTRIBUTE)
-    message(WARNING "These tests has been disabled in OSX before being fixed: \n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext")
+    message(WARNING "These tests has been disabled in OSX before being fixed:\n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext")
     # this op is not support on mac
     list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
     # TODO: add the unitest back when it fixed
@@ -95,13 +95,12 @@ if(WITH_DISTRIBUTE)
     if(NOT APPLE)
         set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
         set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
+	py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
+	set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
         # FIXME(typhoonzero): add these tests back
-	# py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
-	# set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
 	# py_test_modules(test_dist_transformer MODULES test_dist_transformer)
 	# set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
-        # TODO(typhoonzero): make dist test parallel when fix port management issue
-        set_tests_properties(test_dist_mnist test_dist_word2vec test_dist_ctr test_dist_simnet_bow test_dist_save_load test_dist_text_classification test_dist_mnist_batch_merge PROPERTIES RUN_SERIAL TRUE)
+        set_tests_properties(test_dist_ctr test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE)
     endif(NOT APPLE)
     py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/dist_save_load.py b/python/paddle/fluid/tests/unittests/dist_save_load.py
index cf62817956..faec535042 100644
--- a/python/paddle/fluid/tests/unittests/dist_save_load.py
+++ b/python/paddle/fluid/tests/unittests/dist_save_load.py
@@ -102,7 +102,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
 
         if args.mem_opt:
             fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
-        if args.is_dist:
+        if args.update_method == "pserver":
             t = self.get_transpiler(args.trainer_id,
                                     fluid.default_main_program(),
                                     args.endpoints, args.trainers,
@@ -147,7 +147,7 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
 
         def get_data():
             origin_batch = next(reader_generator)
-            if args.is_dist and args.use_reader_alloc:
+            if args.update_method == "pserver" and args.use_reader_alloc:
                 new_batch = []
                 for offset, item in enumerate(origin_batch):
                     if offset % 2 == args.trainer_id:
diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py
index 9f3f2f3481..6cd71e39e4 100644
--- a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py
+++ b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py
@@ -128,6 +128,12 @@ class TestIdentityActivation(TestConv2dFusionOp):
         self.activation = 'identity'
 
 
+class TestIdentityActivationNoResidual(TestConv2dFusionOp):
+    def init_activation(self):
+        self.activation = 'identity'
+        self.add_residual_data = False
+
+
 class TestWithGroup(TestConv2dFusionOp):
     def init_group(self):
         self.groups = 3
diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py
new file mode 100644
index 0000000000..f0e1265e14
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+from test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1
+
+
+class TestMKLDNN(TestConv3dOp):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNCase1(TestCase1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNGroup1(TestWithGroup1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNGroup2(TestWithGroup2):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNWith1x1(TestWith1x1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+if __name__ == '__main__':
+    unittest.main()
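
The MKLDNN cases above reuse the generic conv3d tests wholesale: each subclass only overrides init_kernel_type, which TestConv3dOp.setUp calls before assembling self.attrs, so use_mkldnn and data_format flow into the operator attributes with no other change. A minimal sketch of that hook pattern, with illustrative names that are not part of the patch:

    class Base(object):
        def setUp(self):
            self.use_mkldnn = False    # default kernel selection
            self.init_kernel_type()    # subclass hook, may flip the flag
            self.attrs = {'use_mkldnn': self.use_mkldnn}

        def init_kernel_type(self):
            pass

    class MKLDNNCase(Base):
        def init_kernel_type(self):
            self.use_mkldnn = True     # route the op to the MKLDNN kernel
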
diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py
index 69c5ab7a4a..c6b749fe09 100644
--- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py
@@ -74,6 +74,8 @@ class TestConv3dOp(OpTest):
     def setUp(self):
         self.op_type = "conv3d"
         self.use_cudnn = False
+        self.use_mkldnn = False
+        self.data_format = "AnyLayout"
         self.dtype = np.float32
         self.init_kernel_type()
         self.init_group()
@@ -83,8 +85,7 @@ class TestConv3dOp(OpTest):
         conv3d_param = {
             'stride': self.stride,
             'pad': self.pad,
-            'dilations': self.dilations,
-            'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
+            'dilations': self.dilations
         }
 
         input = np.random.random(self.input_size).astype(self.dtype)
@@ -101,7 +102,9 @@ class TestConv3dOp(OpTest):
             'paddings': self.pad,
             'groups': self.groups,
             'dilations': self.dilations,
-            'use_cudnn': self.use_cudnn
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+            'data_format': self.data_format
         }
         self.outputs = {'Output': output}
 
@@ -109,59 +112,35 @@ class TestConv3dOp(OpTest):
         return core.is_compiled_with_cuda() and self.use_cudnn
 
     def test_check_output(self):
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-5)
-        else:
-            self.check_output()
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_output_with_place(place, atol=1e-5)
 
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place,
-                set(['Input', 'Filter']),
-                'Output',
-                max_relative_error=0.03)
-        else:
-            self.check_grad(
-                set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03)
 
     def test_check_grad_no_filter(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ['Input'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Filter']))
-        else:
-            self.check_grad(
-                ['Input'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Filter']))
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['Input'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ['Filter'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Input']))
-        else:
-            self.check_grad(
-                ['Filter'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Input']))
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['Filter'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Input']))
 
     def init_test_case(self):
         self.pad = [0, 0, 0]
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 97e7ee6229..0a43f53658 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -76,12 +76,24 @@ class TestDistRunnerBase(object):
 
         if args.mem_opt:
             fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
-        if args.is_dist:
+        if args.update_method == "pserver":
             t = self.get_transpiler(args.trainer_id,
                                     fluid.default_main_program(),
                                     args.endpoints, args.trainers,
                                     args.sync_mode, args.dc_asgd)
             trainer_prog = t.get_trainer_program()
+        elif args.update_method == "nccl2":
+            # transpile for nccl2
+            config = fluid.DistributeTranspilerConfig()
+            config.mode = "nccl2"
+            nccl2_t = fluid.DistributeTranspiler(config=config)
+            nccl2_t.transpile(
+                args.trainer_id,
+                program=fluid.default_main_program(),
+                startup_program=fluid.default_startup_program(),
+                trainers=args.endpoints,
+                current_endpoint=args.current_endpoint)
+            trainer_prog = fluid.default_main_program()
         else:
             trainer_prog = fluid.default_main_program()
 
@@ -110,11 +122,20 @@ class TestDistRunnerBase(object):
                 len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
             mypass.set_int("num_repeats", args.batch_merge_repeat)
 
+        if args.update_method == "nccl2":
+            num_trainers = len(args.endpoints.split(","))
+            trainer_id = args.trainer_id
+        else:
+            num_trainers = 1
+            trainer_id = 0
+
         exe = fluid.ParallelExecutor(
             args.use_cuda,
             loss_name=avg_cost.name,
             exec_strategy=strategy,
-            build_strategy=build_stra)
+            build_strategy=build_stra,
+            num_trainers=num_trainers,
+            trainer_id=trainer_id)
 
         feed_var_list = [
             var for var in trainer_prog.global_block().vars.values()
@@ -126,7 +147,7 @@ class TestDistRunnerBase(object):
 
         def get_data():
             origin_batch = next(reader_generator)
-            if args.is_dist and args.use_reader_alloc:
+            if args.update_method != "local" and args.use_reader_alloc:
                 new_batch = []
                 for offset, item in enumerate(origin_batch):
                     if offset % 2 == args.trainer_id:
@@ -151,7 +172,11 @@ def runtime_main(test_class):
     parser.add_argument(
         '--role', type=str, required=True, choices=['pserver', 'trainer'])
     parser.add_argument('--endpoints', type=str, required=False, default="")
-    parser.add_argument('--is_dist', action='store_true')
+    parser.add_argument(
+        '--update_method',
+        type=str,
+        default="local",
+        choices=["pserver", "nccl2", "local"])
     parser.add_argument('--trainer_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
     parser.add_argument(
@@ -170,7 +195,7 @@ def runtime_main(test_class):
     args = parser.parse_args()
 
     model = test_class()
-    if args.role == "pserver" and args.is_dist:
+    if args.role == "pserver" and args.update_method == "pserver":
         model.run_pserver(args)
     else:
         model.run_trainer(args)
@@ -208,6 +233,7 @@ class TestDistBase(unittest.TestCase):
         self._use_reduce = False
         self._dc_asgd = False  # must use with async mode
         self._use_reader_alloc = True
+        self._nccl2_mode = False
         self._setup_config()
         self._after_setup_config()
 
@@ -218,7 +244,7 @@ class TestDistBase(unittest.TestCase):
 
     def start_pserver(self, model_file, check_error_log, required_envs):
         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
-        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
+        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
         ps0_cmd = ps_cmd % \
                   (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                    self._trainers)
@@ -270,7 +296,8 @@ class TestDistBase(unittest.TestCase):
         else:
             env_local = {'CPU_NUM': '1'}
 
-        envs.update(env_local)
+        env_local.update(envs)
+        print("local_cmd: {}, env: {}".format(cmd, env_local))
 
         if check_error_log:
             err_log = open("/tmp/trainer.err.log", "wb")
@@ -278,21 +305,21 @@ class TestDistBase(unittest.TestCase):
                 cmd.split(" "),
                 stdout=subprocess.PIPE,
                 stderr=err_log,
-                env=envs)
+                env=env_local)
         else:
             local_proc = subprocess.Popen(
                 cmd.split(" "),
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
-                env=envs)
+                env=env_local)
 
         local_out, local_err = local_proc.communicate()
 
         if check_error_log:
             err_log.close()
 
-        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))
         sys.stderr.write('local_stderr: %s\n' % local_err)
+        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))
 
         return pickle.loads(local_out)
 
@@ -303,7 +330,7 @@ class TestDistBase(unittest.TestCase):
 
         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
 
-        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
+        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver"
         tr0_cmd = tr_cmd % \
                   (self._python_interp, model, self._ps_endpoints,
                    0, ps0_ep, self._trainers)
@@ -335,8 +362,8 @@ class TestDistBase(unittest.TestCase):
         env0.update(envs)
         env1.update(envs)
 
-        print("tr0_cmd:{}".format(tr0_cmd))
-        print("tr1_cmd:{}".format(tr1_cmd))
+        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
+        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
         tr0_pipe = open("/tmp/tr0_err.log", "wb")
         tr1_pipe = open("/tmp/tr1_err.log", "wb")
 
@@ -357,12 +384,9 @@ class TestDistBase(unittest.TestCase):
         # close trainer file
         tr0_pipe.close()
         tr1_pipe.close()
-
         ps0_pipe.close()
         ps1_pipe.close()
-        # FIXME: use terminate() instead of sigkill.
-        os.kill(ps0.pid, signal.SIGKILL)
-        os.kill(ps1.pid, signal.SIGKILL)
+
         ps0.terminate()
         ps1.terminate()
 
@@ -372,7 +396,71 @@ class TestDistBase(unittest.TestCase):
         sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
         sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
 
-        # return tr0_losses, tr1_losses
+        return pickle.loads(tr0_out), pickle.loads(tr1_out)
+
+    def _run_cluster_nccl2(self, model, envs, check_error_log):
+        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
+        worker_endpoints = self._ps_endpoints.split(",")
+        w0_ep, w1_ep = worker_endpoints
+
+        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2"
+        tr0_cmd = tr_cmd % \
+                  (self._python_interp, model, self._ps_endpoints,
+                   0, w0_ep)
+        tr1_cmd = tr_cmd % \
+                  (self._python_interp, model, self._ps_endpoints,
+                   1, w1_ep)
+
+        if self._mem_opt:
+            tr0_cmd += " --mem_opt"
+            tr1_cmd += " --mem_opt"
+        if self._use_reduce:
+            tr0_cmd += " --use_reduce"
+            tr1_cmd += " --use_reduce"
+        if self._use_reader_alloc:
+            tr0_cmd += " --use_reader_alloc"
+            tr1_cmd += " --use_reader_alloc"
+        if self.__use_cuda:
+            tr0_cmd += " --use_cuda"
+            tr1_cmd += " --use_cuda"
+            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
+            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
+        else:
+            env0 = {'CPU_NUM': '1'}
+            env1 = {'CPU_NUM': '1'}
+
+        env0.update(envs)
+        env1.update(envs)
+
+        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
+        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
+        tr0_pipe = open("/tmp/tr0_err.log", "wb")
+        tr1_pipe = open("/tmp/tr1_err.log", "wb")
+
+        tr0_proc = subprocess.Popen(
+            tr0_cmd.strip().split(" "),
+            stdout=subprocess.PIPE,
+            stderr=tr0_pipe,
+            env=env0)
+        tr1_proc = subprocess.Popen(
+            tr1_cmd.strip().split(" "),
+            stdout=subprocess.PIPE,
+            stderr=tr1_pipe,
+            env=env1)
+
+        tr0_out, tr0_err = tr0_proc.communicate()
+        tr1_out, tr1_err = tr1_proc.communicate()
+
+        # close trainer file
+        tr0_pipe.close()
+        tr1_pipe.close()
+
+        # print log
+        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
+        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
+        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
+        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)
+
         return pickle.loads(tr0_out), pickle.loads(tr1_out)
 
     def check_with_place(self,
@@ -387,20 +475,25 @@ class TestDistBase(unittest.TestCase):
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
             "FLAGS_cudnn_deterministic": "1",
-            "http_proxy": ""
+            "http_proxy": "",
+            "NCCL_P2P_DISABLE": "1"
         }
 
         required_envs.update(need_envs)
 
         if check_error_log:
-            required_envs["GLOG_v"] = "7"
+            required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
 
         local_losses\
             = self._run_local(model_file, required_envs,
                                        check_error_log)
-        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs,
-                                                   check_error_log)
+        if self._nccl2_mode:
+            tr0_losses, tr1_losses = self._run_cluster_nccl2(
+                model_file, required_envs, check_error_log)
+        else:
+            tr0_losses, tr1_losses = self._run_cluster(
+                model_file, required_envs, check_error_log)
 
         for step_id in range(RUN_STEP):
             local_loss = local_losses[step_id]
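
The nccl2 path above differs from pserver mode in two ways: the main program is transpiled in place rather than split across pservers, and ParallelExecutor receives the world size and rank so gradients can be all-reduced over NCCL. A minimal sketch of that wiring for two local workers; the endpoints and the tiny network below are placeholders, not part of this patch:

    import paddle.fluid as fluid

    # placeholder setup; in the test these come from --endpoints/--trainer_id
    trainer_endpoints = "127.0.0.1:6170,127.0.0.1:6171"
    trainer_id = 0
    current_endpoint = trainer_endpoints.split(",")[trainer_id]

    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    loss = fluid.layers.mean(
        fluid.layers.square_error_cost(
            input=fluid.layers.fc(input=x, size=1), label=y))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id,
        trainers=trainer_endpoints,
        current_endpoint=current_endpoint,
        startup_program=fluid.default_startup_program())

    # rank and world size let ParallelExecutor all-reduce gradients over NCCL
    exe = fluid.ParallelExecutor(
        use_cuda=True,
        loss_name=loss.name,
        num_trainers=len(trainer_endpoints.split(",")),
        trainer_id=trainer_id)
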
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
index 81eb651878..630bed198f 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
@@ -26,6 +26,19 @@ class TestDistMnist2x2(TestDistBase):
         self.check_with_place("dist_mnist.py", delta=1e-5)
 
 
+class TestDistMnistNCCL2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+        self._use_reader_alloc = False
+        self._nccl2_mode = True
+
+    def test_dist_train(self):
+        import paddle.fluid as fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place("dist_mnist.py", delta=1)
+
+
 class TestDistMnist2x2Lars(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
diff --git a/python/paddle/fluid/tests/unittests/test_dist_save_load.py b/python/paddle/fluid/tests/unittests/test_dist_save_load.py
index ea2b554dac..4588ca7c17 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_save_load.py
@@ -44,7 +44,7 @@ class TestDistSaveLoadDense2x2(TestDistBase):
         required_envs.update(need_envs)
 
         if check_error_log:
-            required_envs["GLOG_v"] = "7"
+            required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
 
         model_dir = tempfile.mkdtemp()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index 194387bc98..d9ad4e2e2c 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -769,6 +769,7 @@ class TestNCCL2Transpile(TranspilerTest):
 
             config = fluid.DistributeTranspilerConfig()
             config.mode = "nccl2"
+            config.wait_port = False
             t = fluid.DistributeTranspiler(config=config)
             t.transpile(
                 0,
diff --git a/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py
new file mode 100644
index 0000000000..021b950b3b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py
@@ -0,0 +1,65 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid.core as core
+import numpy as np
+from paddle.fluid.op import Operator
+
+
+class TestGetTensorFromSelectedRows(unittest.TestCase):
+    def get_places(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        return places
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        x_rows = [0, 5, 5, 4, 20]
+        height = 20
+        row_numel = 2
+
+        np_array = np.ones((len(x_rows), row_numel)).astype("float32")
+        np_array[1, :] = 2.0
+        np_array[2, :] = 3.0
+        np_array[3, :] = 4.0
+
+        # initialize input variable X
+        x = scope.var('X').get_selected_rows()
+        x.set_rows(x_rows)
+        x.set_height(height)
+        x_tensor = x.get_tensor()
+        x_tensor.set(np_array, place)
+
+        # create output variable Out
+        out = scope.var("Out").get_tensor()
+
+        op = Operator("get_tensor_from_selected_rows", X="X", Out="Out")
+
+        op.run(scope, place)
+
+        out_array = np.array(out)
+        self.assertEqual((5, 2), out_array.shape)
+        assert (out_array == np_array).all()
+
+    def test_check_output(self):
+        for place in self.get_places():
+            self.check_with_place(place)
+
+
+if __name__ == "__main__":
+    unittest.main()
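
Restating what the test asserts in plain numpy: get_tensor_from_selected_rows copies the SelectedRows' value tensor into an ordinary LoDTensor as-is; duplicate row indices are neither merged nor scattered back to their positions (merging is the job of merge_selected_rows, tested below):

    import numpy as np

    x_rows = [0, 5, 5, 4, 20]              # row indices, duplicates allowed
    values = np.ones((len(x_rows), 2), dtype="float32")
    values[1, :], values[2, :], values[3, :] = 2.0, 3.0, 4.0

    expected_out = values.copy()           # Out == X's value tensor, unchanged
    assert expected_out.shape == (5, 2)
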
diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
new file mode 100644
index 0000000000..e49239da6d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -0,0 +1,161 @@
+#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+
+
+def bow_net(data,
+            label,
+            dict_dim,
+            emb_dim=128,
+            hid_dim=128,
+            hid_dim2=96,
+            class_dim=2):
+    """
+    BOW net
+    This model is from https://github.com/PaddlePaddle/models:
+    fluid/PaddleNLP/text_classification/nets.py
+    """
+    emb = fluid.layers.embedding(
+        input=data, is_sparse=True, size=[dict_dim, emb_dim])
+    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
+    bow_tanh = fluid.layers.tanh(bow)
+    fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
+    fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
+    prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+
+    return avg_cost
+
+
+class TestGradientClip(unittest.TestCase):
+    def setUp(self):
+        self.word_dict = paddle.dataset.imdb.word_dict()
+        self.BATCH_SIZE = 2
+        self.train_data = paddle.batch(
+            paddle.dataset.imdb.train(self.word_dict),
+            batch_size=self.BATCH_SIZE)
+
+    def get_places(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        return places
+
+    def check_operators(self, place):
+        CLIP = 1
+
+        prog = fluid.framework.Program()
+        startup_program = fluid.framework.Program()
+        with fluid.program_guard(
+                main_program=prog, startup_program=startup_program):
+            image = fluid.layers.data(name='x', shape=[784], dtype='float32')
+            label = fluid.layers.data(name='y', shape=[1], dtype='int64')
+
+            hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
+            hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
+            predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
+
+            cost = fluid.layers.cross_entropy(input=predict, label=label)
+            avg_cost = fluid.layers.mean(cost)
+
+        prog_clip = prog.clone()
+        avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
+
+        p_g = fluid.backward.append_backward(loss=avg_cost)
+        p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
+
+        with fluid.program_guard(
+                main_program=prog_clip, startup_program=startup_program):
+            fluid.clip.set_gradient_clip(
+                fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
+            p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
+
+        grad_list = [elem[1] for elem in p_g]
+        grad_clip_list = [elem[1] for elem in p_g_clip]
+
+        train_reader = paddle.batch(
+            paddle.reader.shuffle(
+                paddle.dataset.mnist.train(), buf_size=8192),
+            batch_size=128)
+
+        exe = fluid.Executor(place)
+        feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
+        exe.run(startup_program)
+
+        count = 0
+        for data in train_reader():
+            count += 1
+            if count > 5:
+                break
+            out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
+            out_clip = exe.run(prog_clip,
+                               feed=feeder.feed(data),
+                               fetch_list=grad_clip_list)
+            global_norm = 0
+            for v in out:
+                global_norm += np.sum(np.power(v, 2))
+            global_norm = np.sqrt(global_norm)
+
+            global_norm_clip = 0
+            for v in out_clip:
+                global_norm_clip += np.sum(np.power(v, 2))
+            global_norm_clip = np.sqrt(global_norm_clip)
+
+            assert np.isclose(
+                a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3)
+
+    def check_sparse_gradient_clip(self, place):
+        prog = fluid.framework.Program()
+        startup_program = fluid.framework.Program()
+        with fluid.program_guard(
+                main_program=prog, startup_program=startup_program):
+            data = fluid.layers.data(
+                name="words", shape=[1], dtype="int64", lod_level=1)
+            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+            cost = bow_net(data, label, len(self.word_dict))
+
+            fluid.clip.set_gradient_clip(
+                clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
+
+            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
+            sgd_optimizer.minimize(cost)
+
+        exe = fluid.Executor(place)
+        feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
+        exe.run(startup_program)
+
+        data = next(self.train_data())
+        val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0]
+        self.assertEqual((1, ), val.shape)
+        print(val)
+        self.assertFalse(np.isnan(val))
+
+    def test_operators(self):
+        self.check_operators(core.CPUPlace())
+
+    def test_sparse_gradient_clip(self):
+        for place in self.get_places():
+            self.check_sparse_gradient_clip(place)
+
+
+if __name__ == '__main__':
+    unittest.main()
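
The assertion inside check_operators is the defining property of GradientClipByGlobalNorm: each gradient is scaled by min(1, clip_norm / global_norm), so the post-clip global norm equals min(global_norm, clip_norm). A numpy restatement of that invariant, with arbitrary placeholder shapes:

    import numpy as np

    def clip_by_global_norm(grads, clip_norm):
        # grads: the list of gradient arrays of one optimization step
        global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
        scale = clip_norm / max(global_norm, clip_norm)  # == min(1, clip/norm)
        return [g * scale for g in grads]

    grads = [np.random.randn(3, 4), np.random.randn(5)]
    clipped = clip_by_global_norm(grads, clip_norm=1.0)
    old_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    new_norm = np.sqrt(sum(np.sum(np.square(g)) for g in clipped))
    assert np.isclose(new_norm, min(old_norm, 1.0), rtol=5e-3)
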
diff --git a/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py
new file mode 100644
index 0000000000..ce64da0478
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py
@@ -0,0 +1,73 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid.core as core
+import numpy as np
+from paddle.fluid.op import Operator
+
+
+class TestMergeSelectedRows(unittest.TestCase):
+    def get_places(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        return places
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        x_rows = [0, 5, 5, 4, 20]
+        out_rows = [0, 4, 5, 20]
+        height = 20
+        row_numel = 2
+
+        np_array = np.ones((len(x_rows), row_numel)).astype("float32")
+        np_array[1, :] = 2.0
+        np_array[2, :] = 3.0
+        np_array[3, :] = 4.0
+
+        # initialize input variable X
+        x = scope.var('X').get_selected_rows()
+        x.set_rows(x_rows)
+        x.set_height(height)
+        x_tensor = x.get_tensor()
+        x_tensor.set(np_array, place)
+
+        # create output variable Out
+        out = scope.var("Out").get_selected_rows()
+
+        op = Operator("merge_selected_rows", X="X", Out="Out")
+
+        op.run(scope, place)
+
+        self.assertEqual(out.rows(), out_rows)
+        self.assertEqual(out.height(), height)
+
+        out_array = np.array(out.get_tensor())
+        self.assertEqual((4, 2), out_array.shape)
+
+        assert (out_array[0, :] == 1.0).all()
+        assert (out_array[1, :] == 4.0).all()
+        assert (out_array[2, :] == 5.0).all()
+        assert (out_array[3, :] == 1.0).all()
+
+    def test_check_output(self):
+        for place in self.get_places():
+            self.check_with_place(place)
+
+
+if __name__ == "__main__":
+    unittest.main()
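
The expected values in this test follow directly from the op's contract: merge_selected_rows emits the unique row indices in ascending order and sums the value slices of duplicates (the two slices for row 5, holding 2.0 and 3.0, merge into 5.0). A numpy sketch of that contract, checked against the test's data:

    import numpy as np

    def merge_selected_rows(rows, values):
        # sum duplicate row indices; emit unique rows in ascending order
        out_rows = sorted(set(rows))
        out = np.zeros((len(out_rows), values.shape[1]), dtype=values.dtype)
        for r, v in zip(rows, values):
            out[out_rows.index(r)] += v
        return out_rows, out

    values = np.ones((5, 2), dtype="float32")
    values[1, :], values[2, :], values[3, :] = 2.0, 3.0, 4.0
    out_rows, out = merge_selected_rows([0, 5, 5, 4, 20], values)
    assert out_rows == [0, 4, 5, 20]
    assert (out[:, 0] == [1.0, 4.0, 5.0, 1.0]).all()
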
diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
new file mode 100644
index 0000000000..544fe4b4f8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
@@ -0,0 +1,215 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+from paddle.fluid import core
+
+
+def sigmoid(x):
+    return 1.0 / (1.0 + np.exp(-1.0 * x))
+
+
+def mse(x, y, num):
+    return ((y - x)**2).sum() / num
+
+
+def bce(x, y, mask):
+    x = x.reshape((-1))
+    y = y.reshape((-1))
+    mask = mask.reshape((-1))
+
+    error_sum = 0.0
+    count = 0
+    for i in range(x.shape[0]):
+        if mask[i] > 0:
+            error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i])
+            count += 1
+    return error_sum / (-1.0 * count)
+
+
+def box_iou(box1, box2):
+    b1_x1 = box1[0] - box1[2] / 2
+    b1_x2 = box1[0] + box1[2] / 2
+    b1_y1 = box1[1] - box1[3] / 2
+    b1_y2 = box1[1] + box1[3] / 2
+    b2_x1 = box2[0] - box2[2] / 2
+    b2_x2 = box2[0] + box2[2] / 2
+    b2_y1 = box2[1] - box2[3] / 2
+    b2_y2 = box2[1] + box2[3] / 2
+
+    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+
+    inter_rect_x1 = max(b1_x1, b2_x1)
+    inter_rect_y1 = max(b1_y1, b2_y1)
+    inter_rect_x2 = min(b1_x2, b2_x2)
+    inter_rect_y2 = min(b1_y2, b2_y2)
+    inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max(
+        inter_rect_y2 - inter_rect_y1, 0)
+
+    return inter_area / (b1_area + b2_area - inter_area)
+
+
+def build_target(gtboxs, gtlabel, attrs, grid_size):
+    n, b, _ = gtboxs.shape
+    ignore_thresh = attrs["ignore_thresh"]
+    anchors = attrs["anchors"]
+    class_num = attrs["class_num"]
+    an_num = len(anchors) // 2
+    obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
+    tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
+    tcls = np.zeros(
+        (n, an_num, grid_size, grid_size, class_num)).astype('float32')
+
+    for i in range(n):
+        for j in range(b):
+            if gtboxs[i, j, :].sum() == 0:
+                continue
+
+            gt_label = gtlabel[i, j]
+            gx = gtboxs[i, j, 0] * grid_size
+            gy = gtboxs[i, j, 1] * grid_size
+            gw = gtboxs[i, j, 2] * grid_size
+            gh = gtboxs[i, j, 3] * grid_size
+
+            gi = int(gx)
+            gj = int(gy)
+
+            gtbox = [0, 0, gw, gh]
+            max_iou = 0
+            for k in range(an_num):
+                anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]]
+                iou = box_iou(gtbox, anchor_box)
+                if iou > max_iou:
+                    max_iou = iou
+                    best_an_index = k
+                if iou > ignore_thresh:
+                    noobj_mask[i, k, gj, gi] = 0
+
+            obj_mask[i, best_an_index, gj, gi] = 1
+            noobj_mask[i, best_an_index, gj, gi] = 0
+            tx[i, best_an_index, gj, gi] = gx - gi
+            ty[i, best_an_index, gj, gi] = gy - gj
+            tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 *
+                                                               best_an_index])
+            th[i, best_an_index, gj, gi] = np.log(
+                gh / anchors[2 * best_an_index + 1])
+            tconf[i, best_an_index, gj, gi] = 1
+            tcls[i, best_an_index, gj, gi, gt_label] = 1
+
+    return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask)
+
+
+def YoloV3Loss(x, gtbox, gtlabel, attrs):
+    n, c, h, w = x.shape
+    an_num = len(attrs['anchors']) // 2
+    class_num = attrs["class_num"]
+    x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
+    pred_x = sigmoid(x[:, :, :, :, 0])
+    pred_y = sigmoid(x[:, :, :, :, 1])
+    pred_w = x[:, :, :, :, 2]
+    pred_h = x[:, :, :, :, 3]
+    pred_conf = sigmoid(x[:, :, :, :, 4])
+    pred_cls = sigmoid(x[:, :, :, :, 5:])
+
+    tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target(
+        gtbox, gtlabel, attrs, x.shape[2])
+
+    obj_mask_expand = np.tile(
+        np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num'])))
+    loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum())
+    loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum())
+    loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum())
+    loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum())
+    loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask)
+    loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask,
+                             noobj_mask)
+    loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand,
+                     obj_mask_expand)
+
+    return attrs['loss_weight_xy'] * (loss_x + loss_y) \
+            + attrs['loss_weight_wh'] * (loss_w + loss_h) \
+            + attrs['loss_weight_conf_target'] * loss_conf_target \
+            + attrs['loss_weight_conf_notarget'] * loss_conf_notarget \
+            + attrs['loss_weight_class'] * loss_class
+
+
+class TestYolov3LossOp(OpTest):
+    def setUp(self):
+        self.loss_weight_xy = 1.0
+        self.loss_weight_wh = 1.0
+        self.loss_weight_conf_target = 1.0
+        self.loss_weight_conf_notarget = 1.0
+        self.loss_weight_class = 1.0
+        self.initTestCase()
+        self.op_type = 'yolov3_loss'
+        x = np.random.random(size=self.x_shape).astype('float32')
+        gtbox = np.random.random(size=self.gtbox_shape).astype('float32')
+        gtlabel = np.random.randint(0, self.class_num,
+                                    self.gtbox_shape[:2]).astype('int32')
+
+        self.attrs = {
+            "anchors": self.anchors,
+            "class_num": self.class_num,
+            "ignore_thresh": self.ignore_thresh,
+            "loss_weight_xy": self.loss_weight_xy,
+            "loss_weight_wh": self.loss_weight_wh,
+            "loss_weight_conf_target": self.loss_weight_conf_target,
+            "loss_weight_conf_notarget": self.loss_weight_conf_notarget,
+            "loss_weight_class": self.loss_weight_class,
+        }
+
+        self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel}
+        self.outputs = {
+            'Loss': np.array(
+                [YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32')
+        }
+
+    def test_check_output(self):
+        place = core.CPUPlace()
+        self.check_output_with_place(place, atol=1e-3)
+
+    def test_check_grad_ignore_gtbox(self):
+        place = core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['X'],
+            'Loss',
+            no_grad_set=set(["GTBox", "GTLabel"]),
+            max_relative_error=0.06)
+
+    def initTestCase(self):
+        self.anchors = [10, 13, 12, 12]
+        self.class_num = 10
+        self.ignore_thresh = 0.5
+        self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7)
+        self.gtbox_shape = (5, 10, 4)
+        self.loss_weight_xy = 2.5
+        self.loss_weight_wh = 0.8
+        self.loss_weight_conf_target = 1.5
+        self.loss_weight_conf_notarget = 0.5
+        self.loss_weight_class = 1.2
+
+
+if __name__ == "__main__":
+    unittest.main()
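
build_target encodes targets exactly as in the YOLOv3 paper: the box center becomes a sub-cell offset (inverted by sigmoid in the op's forward) and width/height become log ratios against the matched anchor (inverted by exp). A small round-trip check of that encoding; the grid size and anchor come from initTestCase, the box itself is a placeholder:

    import numpy as np

    grid_size, anchor_w, anchor_h = 7, 10, 13
    gx, gy = 0.43 * grid_size, 0.61 * grid_size   # placeholder center, grid units
    gw, gh = 0.20 * grid_size, 0.30 * grid_size   # placeholder size, grid units
    gi, gj = int(gx), int(gy)

    # encode as build_target does
    tx, ty = gx - gi, gy - gj
    tw, th = np.log(gw / anchor_w), np.log(gh / anchor_h)

    # decoding inverts the encoding exactly
    assert np.isclose(gi + tx, gx) and np.isclose(gj + ty, gy)
    assert np.isclose(anchor_w * np.exp(tw), gw)
    assert np.isclose(anchor_h * np.exp(th), gh)
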
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 5d348f0995..d21ec42dcc 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -125,13 +125,14 @@ def slice_variable(var_list, slice_count, min_block_size):
 
 class DistributeTranspilerConfig(object):
     """
-    slice_var_up (bool): Do Tensor slice for pservers, default is True.
-    split_method (PSDispatcher): RoundRobin or HashName can be used
-        try to choose the best method to balance loads for pservers.
-    min_block_size (int): Minimum splitted element number in block.
-        According:https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156
-        We can use bandwidth effiently when data size is larger than 2MB.If you
-        want to change it, please be sure you see the slice_variable function.
+    Args:
+        slice_var_up (bool): Whether to do Tensor slicing for pservers, default is True.
+        split_method (PSDispatcher): RoundRobin or HashName can be used;
+          try to choose the best method to balance loads for pservers.
+        min_block_size (int): Minimum number of split elements in a block,
+          according to https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156.
+          Bandwidth is used efficiently when the data size is larger than 2MB. If you
+          want to change it, please make sure you read the slice_variable function first.
     """
 
     slice_var_up = True
@@ -141,6 +142,7 @@ class DistributeTranspilerConfig(object):
     # supported modes: pserver, nccl2
     mode = "pserver"
     print_log = False
+    wait_port = True
 
 
 class DistributeTranspiler(object):
@@ -163,35 +165,34 @@ class DistributeTranspiler(object):
     Examples:
         .. code-block:: python
 
-           # for pserver mode
-           pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
-           trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
-           current_endpoint = "192.168.0.1:6174"
-           trainer_id = 0
-           trainers = 4
-           role = os.getenv("PADDLE_TRAINING_ROLE")
-
-           t = fluid.DistributeTranspiler()
-           t.transpile(
-                trainer_id, pservers=pserver_endpoints, trainers=trainers)
-           if role == "PSERVER":
-                pserver_program = t.get_pserver_program(current_endpoint)
-                pserver_startup_program = t.get_startup_program(current_endpoint,
+            # for pserver mode
+            pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
+            trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
+            current_endpoint = "192.168.0.1:6174"
+            trainer_id = 0
+            trainers = 4
+            role = os.getenv("PADDLE_TRAINING_ROLE")
+            t = fluid.DistributeTranspiler()
+            t.transpile(
+                 trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if role == "PSERVER":
+                 pserver_program = t.get_pserver_program(current_endpoint)
+                 pserver_startup_program = t.get_startup_program(current_endpoint,
                                                                 pserver_program)
-           elif role == "TRAINER":
-                trainer_program = t.get_trainer_program()
-
-           # for nccl2 mode
-           config = fluid.DistributeTranspilerConfig()
-           config.mode = "nccl2"
-           t = fluid.DistributeTranspiler(config=config)
-           t.transpile(trainer_id, workers=workers, current_endpoint=curr_ep)
-           exe = fluid.ParallelExecutor(
-               use_cuda,
-               loss_name=loss_var.name,
-               num_trainers=len(trainers.split(",)),
-               trainer_id=trainer_id
-           )
+            elif role == "TRAINER":
+                 trainer_program = t.get_trainer_program()
+
+            # for nccl2 mode
+            config = fluid.DistributeTranspilerConfig()
+            config.mode = "nccl2"
+            t = fluid.DistributeTranspiler(config=config)
+            t.transpile(trainer_id, workers=workers, current_endpoint=curr_ep)
+            exe = fluid.ParallelExecutor(
+                use_cuda,
+                loss_name=loss_var.name,
+                num_trainers=len(trainers.split(",)),
+                trainer_id=trainer_id
+            )
     """
 
     def __init__(self, config=None):
@@ -213,13 +214,16 @@ class DistributeTranspiler(object):
                          trainer_id,
                          trainers,
                          current_endpoint,
-                         startup_program=None):
+                         startup_program=None,
+                         wait_port=True):
         if not startup_program:
             startup_program = default_startup_program()
         if trainer_id >= 0:
             worker_endpoints = trainers.split(",")
             # send NCCL_ID to others or recv from trainer 0
             worker_endpoints.remove(current_endpoint)
+            if trainer_id == 0 and wait_port:
+                wait_server_ready(worker_endpoints)
 
             nccl_id_var = startup_program.global_block().create_var(
                 name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
@@ -301,11 +305,13 @@ class DistributeTranspiler(object):
 
         if self.config.mode == "nccl2":
             assert (isinstance(trainers, str))
+            self.origin_program._trainers_endpoints = trainers.split(",")
             self._transpile_nccl2(
                 trainer_id,
                 trainers,
                 current_endpoint,
-                startup_program=startup_program)
+                startup_program=startup_program,
+                wait_port=self.config.wait_port)
             return
 
         self.trainer_num = trainers
@@ -651,9 +657,6 @@ class DistributeTranspiler(object):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-        sys.stderr.write("get_pserver_program() is deprecated, call \
-get_pserver_programs() to get pserver main and startup \
-in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -921,18 +924,6 @@ in a single call.")
         Returns:
             Program: parameter server side startup program.
         """
-        sys.stderr.write("get_startup_program() is deprecated, call \
-get_pserver_programs() to get pserver main and startup \
-in a single call.")
-        if pserver_program != None:
-            sys.stderr.write("passing pserver_program to get_startup_program() \
-is deprecated, you can use new API get_pserver_programs() to \
-get both pserver main program and startup program.")
-        if startup_program != None:
-            sys.stderr.write("passing startup_program to get_startup_program() \
-is deprecated, use fluid.program_guard() or pass this argument \
-to transpile() call.")
-
         s_prog = Program()
         orig_s_prog = self.startup_program
         s_prog.random_seed = orig_s_prog.random_seed
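
The new wait_port flag only gates a readiness check: before trainer 0 broadcasts the NCCL ID it waits until every peer endpoint accepts a TCP connection, and tests such as TestNCCL2Transpile set config.wait_port = False to skip that wait. Roughly the kind of loop wait_server_ready performs, as a simplified sketch rather than the helper's actual code:

    import socket
    import time

    def wait_server_ready(endpoints):
        # block until every "ip:port" endpoint accepts a TCP connection
        for ep in endpoints:
            ip, port = ep.split(":")
            while True:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.settimeout(1.0)
                try:
                    if sock.connect_ex((ip, int(port))) == 0:
                        break
                finally:
                    sock.close()
                time.sleep(1.0)
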
diff --git a/tools/print_signatures.py b/tools/print_signatures.py
index e2805c4e7e..5c5266f904 100644
--- a/tools/print_signatures.py
+++ b/tools/print_signatures.py
@@ -15,7 +15,7 @@
 Print all signature of a python module in alphabet order.
 
 Usage:
-    ./print_signature  "paddle.fluid" > signature.txt
+    ./print_signatures.py "paddle.fluid,paddle.reader" > signature.txt
 """
 from __future__ import print_function
 
@@ -43,7 +43,8 @@ def visit_member(parent_name, member):
                 line.strip() for line in pydoc.render_doc(member).split('\n')
                 if "->" in line
             ])
-
+    elif inspect.isgetsetdescriptor(member):
+        return
     else:
         raise RuntimeError("Unsupported generate signature of member, type {0}".
                            format(str(type(member))))
@@ -63,7 +64,9 @@ def visit_all_module(mod):
             visit_member(mod.__name__, instance)
 
 
-visit_all_module(importlib.import_module(sys.argv[1]))
+modules = sys.argv[1].split(",")
+for m in modules:
+    visit_all_module(importlib.import_module(m))
 
 for name in member_dict:
     print(name, member_dict[name])