diff --git a/CMakeLists.txt b/CMakeLists.txt
index 98e1ac9f26..ed3120f399 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -175,6 +175,7 @@ include(external/any)       # download libn::any
 include(external/eigen)     # download eigen3
 include(external/pybind11)  # download pybind11
 include(external/cares)
+include(external/cub)
 
 if(WITH_DISTRIBUTE)
     if(WITH_GRPC)
diff --git a/cmake/external/cub.cmake b/cmake/external/cub.cmake
new file mode 100644
index 0000000000..c94849cf4b
--- /dev/null
+++ b/cmake/external/cub.cmake
@@ -0,0 +1,35 @@
+if(NOT WITH_GPU)
+  return()
+endif()
+
+include(ExternalProject)
+
+set(CUB_SOURCE_DIR ${THIRD_PARTY_PATH}/cub)
+set(CUB_INCLUDE_DIR ${CUB_SOURCE_DIR}/src/extern_cub)
+
+include_directories(${CUB_INCLUDE_DIR})
+
+ExternalProject_Add(
+  extern_cub
+  ${EXTERNAL_PROJECT_LOG_ARGS}
+  GIT_REPOSITORY "https://github.com/NVlabs/cub.git"
+  GIT_TAG        "v1.8.0"
+  PREFIX         ${CUB_SOURCE_DIR}
+  UPDATE_COMMAND ""
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND     ""
+  INSTALL_COMMAND   ""
+  TEST_COMMAND      ""
+)
+
+if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
+  set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cub_dummy.c)
+  file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
+  add_library(cub STATIC ${dummyfile})
+else()
+  add_library(cub INTERFACE)
+endif()
+
+add_dependencies(cub extern_cub)
+
+LIST(APPEND externl_project_dependencies cub)
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 3ef317bb7a..dd172ff9c9 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -336,6 +336,7 @@ paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=Non
 paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False))
 paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 5ca2ed8f96..a4fdbcb26d 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -275,7 +275,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       if (strategy_.gradient_scale_ !=
           BuildStrategy::GradientScaleStrategy::kCustomized) {
         // TODO(paddle-dev): Why is there no input for this op_handle?
-        CreateScaleLossGradOp(&result);
+        auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
+        CreateScaleLossGradOp(&result, loss_grad_name);
       }
       // This assumes the backward generating code will ensure IsScaleLossOp
       // is true only for the op that scale the final scalar loss.
@@ -535,7 +536,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
   return got == sharded_var_device.end() ? -1 : got->second;
 }
 
-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
+    ir::Graph *result, const std::string &loss_grad_name) const {
   for (size_t i = 0; i < places_.size(); ++i) {
 // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
@@ -558,10 +560,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
     // loss->pending_ops_.emplace_back(op_handle);
     // op_handle->inputs_.emplace_back(loss);
 
-    CreateOpOutput(result, op_handle,
-                   result->CreateEmptyNode(GradVarName(loss_var_name_),
-                                           ir::Node::Type::kVariable),
-                   places_[i], i);
+    CreateOpOutput(
+        result, op_handle,
+        result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
+        places_[i], i);
   }
 }
 
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 099dbe5abe..f2cb6bb1c8 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -75,7 +75,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                               size_t num_places) const;
 
-  void CreateScaleLossGradOp(ir::Graph *result) const;
+  void CreateScaleLossGradOp(ir::Graph *result,
+                             const std::string &loss_grad_name) const;
+
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                             int dst_dev_id) const;
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index c2800c972a..dad170ed78 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -330,12 +330,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }
 
   for (auto& op : ctx->ops_) {
-    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
     op->Run(*local_scope, place_);
-    // NOTE! Please do not delete this line, it's usefull because the debug
-    // string before and after op.run are different, after run the output
-    // will have right shape which is usefull for debug.
-    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
 
     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 0c8acf71bf..d04f774496 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -127,7 +127,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
 }
 
 void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
-  VLOG(10) << "- " << DebugStringEx(&scope);
+  VLOG(4) << place << " " << DebugStringEx(&scope);
   if (platform::is_gpu_place(place)) {
 #ifndef PADDLE_WITH_CUDA
     PADDLE_THROW("Cannot run operator on place %s", place);
@@ -139,7 +139,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   platform::RecordEvent record_event(Type(), pool.Get(place));
   RunImpl(scope, place);
-  VLOG(10) << "+ " << DebugStringEx(&scope);
+  VLOG(3) << place << " " << DebugStringEx(&scope);
 }
 
 bool OperatorBase::HasInputs(const std::string& name) const {
@@ -778,6 +778,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
   auto& scope = ctx.scope();
   int data_type = -1;
+  std::string last_input_name;
   for (auto& input : this->inputs_) {
     for (auto& ipt_name : input.second) {
       auto* var = scope.FindVar(ipt_name);
@@ -794,9 +795,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
           int tmp = static_cast<int>(ToDataType(t->type()));
           PADDLE_ENFORCE(
               tmp == data_type || data_type == -1,
-              "DataType of Paddle Op %s must be the same. Get %d != %d", Type(),
-              data_type, tmp);
+              "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
+              Type(), last_input_name, data_type, ipt_name, tmp);
           data_type = tmp;
+          last_input_name = ipt_name;
         }
       }
     }
diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc
index e74f23ff96..63c3f0d7b3 100644
--- a/paddle/fluid/inference/api/api.cc
+++ b/paddle/fluid/inference/api/api.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <glog/logging.h>
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 
 namespace paddle {
@@ -40,19 +41,36 @@ PaddleBuf::PaddleBuf(PaddleBuf&& other)
 PaddleBuf::PaddleBuf(const PaddleBuf& other) { *this = other; }
 
 PaddleBuf& PaddleBuf::operator=(const PaddleBuf& other) {
+  if (!other.memory_owned_) {
+    data_ = other.data_;
+    length_ = other.length_;
+    memory_owned_ = other.memory_owned_;
+  } else {
+    Resize(other.length());
+    memcpy(data_, other.data(), other.length());
+    length_ = other.length();
+    memory_owned_ = true;
+  }
+  return *this;
+}
+
+PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
   // only the buffer with external memory can be copied
-  assert(!other.memory_owned_);
   data_ = other.data_;
   length_ = other.length_;
   memory_owned_ = other.memory_owned_;
+  other.data_ = nullptr;
+  other.length_ = 0;
+  other.memory_owned_ = false;
   return *this;
 }
 
 void PaddleBuf::Resize(size_t length) {
   // Only the owned memory can be reset, the external memory can't be changed.
   if (length_ == length) return;
-  assert(memory_owned_);
-  Free();
+  if (memory_owned_) {
+    Free();
+  }
   data_ = new char[length];
   length_ = length;
   memory_owned_ = true;
@@ -68,7 +86,7 @@ void PaddleBuf::Reset(void* data, size_t length) {
 void PaddleBuf::Free() {
   if (memory_owned_ && data_) {
     assert(length_ > 0);
-    delete static_cast<char*>(data_);
+    delete[] static_cast<char*>(data_);
     data_ = nullptr;
     length_ = 0;
   }
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 59b0df7968..b24414e824 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -40,11 +40,12 @@ class PaddleBuf {
   // Copy only available when memory is managed externally.
   explicit PaddleBuf(const PaddleBuf&);
   PaddleBuf& operator=(const PaddleBuf&);
+  PaddleBuf& operator=(PaddleBuf&&);
   // Do not own the memory.
   PaddleBuf(void* data, size_t length)
       : data_(data), length_(length), memory_owned_{false} {}
   // Own memory.
-  explicit PaddleBuf(size_t length)
+  PaddleBuf(size_t length)
       : data_(new char[length]), length_(length), memory_owned_(true) {}
   // Resize to `length` bytes.
   void Resize(size_t length);
diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h
index eb8272e90c..bc3e95e904 100644
--- a/paddle/fluid/operators/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise_op_function.h
@@ -534,8 +534,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                          const framework::Tensor& dout, int axis,
                          framework::Tensor* dx, framework::Tensor* dy,
                          DX_OP dx_op, DY_OP dy_op) {
-  const framework::DDim x_dim = x.dims();
-  const framework::DDim y_dim = y.dims();
+  const framework::DDim& x_dim = x.dims();
+  const framework::DDim& y_dim = y.dims();
   if (x.dims() == y.dims()) {
     ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
         ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
@@ -558,19 +558,19 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
                                  framework::Tensor* dx, framework::Tensor* dy,
                                  DX_OP dx_op, DY_OP dy_op) {
   if (dy == nullptr) {
-    const framework::DDim dx_dims = dout.dims();
+    const framework::DDim& dx_dims = dout.dims();
     auto dy_dims = dx_dims;
     ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
         ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   } else {
     if (dout.dims() == dy->dims()) {
-      const framework::DDim dx_dims = dout.dims();
-      const framework::DDim dy_dims = dy->dims();
+      const framework::DDim& dx_dims = dout.dims();
+      const framework::DDim& dy_dims = dy->dims();
       ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
           ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
     } else {  // Y is a scalar
       auto dx_dims = dout.dims();
-      const framework::DDim dy_dims = dy->dims();
+      const framework::DDim& dy_dims = dy->dims();
       ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
           ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
     }
diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc
index 916cdad3fd..eb09470f37 100644
--- a/paddle/fluid/operators/parallel_do_op.cc
+++ b/paddle/fluid/operators/parallel_do_op.cc
@@ -163,12 +163,11 @@ class ParallelDoOp : public framework::OperatorBase {
       auto &place = places[place_idx];
       auto *cur_scope = sub_scopes[place_idx];
 
-      workers.emplace_back(
-          framework::Async([program, cur_scope, place, block, place_idx] {
-            framework::Executor executor(place);
-            executor.Run(*program, cur_scope, block->ID(),
-                         false /*create_local_scope*/);
-          }));
+      workers.emplace_back(framework::Async([program, cur_scope, place, block] {
+        framework::Executor executor(place);
+        executor.Run(*program, cur_scope, block->ID(),
+                     false /*create_local_scope*/);
+      }));
     }
     for (auto &worker : workers) {
       worker.wait();
@@ -239,12 +238,11 @@ class ParallelDoGradOp : public framework::OperatorBase {
       auto *cur_scope = sub_scopes[i];
 
       // execute
-      workers.emplace_back(
-          framework::Async([program, cur_scope, place, block, i] {
-            framework::Executor executor(place);
-            executor.Run(*program, cur_scope, block->ID(),
-                         false /*create_local_scope*/);
-          }));
+      workers.emplace_back(framework::Async([program, cur_scope, place, block] {
+        framework::Executor executor(place);
+        executor.Run(*program, cur_scope, block->ID(),
+                     false /*create_local_scope*/);
+      }));
     }
     for (auto &worker : workers) {
       worker.wait();
diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py
index 4e94ce8989..a7c3c5402e 100644
--- a/python/paddle/dataset/conll05.py
+++ b/python/paddle/dataset/conll05.py
@@ -29,13 +29,13 @@ __all__ = ['test, get_dict', 'get_embedding', 'convert']
 
 DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
 DATA_MD5 = '387719152ae52d60422c016e92a742fc'
-WORDDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/wordDict.txt'
+WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
 WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
-VERBDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/verbDict.txt'
+VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
 VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
-TRGDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/targetDict.txt'
+TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
 TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
-EMB_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/emb'
+EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
 EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
 
 UNK_IDX = 0
diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py
index f0908c7378..7a157e3497 100644
--- a/python/paddle/dataset/wmt14.py
+++ b/python/paddle/dataset/wmt14.py
@@ -40,7 +40,7 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
              'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
-URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
+URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt%2Fwmt14.tgz'
 MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3'
 
 START = "<s>"
diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py
index 12cd5d918e..9de9e95045 100644
--- a/python/paddle/fluid/contrib/__init__.py
+++ b/python/paddle/fluid/contrib/__init__.py
@@ -14,5 +14,7 @@
 
 import decoder
 from decoder import *
+import memory_usage_calc
+from memory_usage_calc import *
 
-__all__ = decoder.__all__
+__all__ = decoder.__all__ + memory_usage_calc.__all__
diff --git a/python/paddle/fluid/contrib/memory_usage_calc.py b/python/paddle/fluid/contrib/memory_usage_calc.py
new file mode 100644
index 0000000000..5da846edb6
--- /dev/null
+++ b/python/paddle/fluid/contrib/memory_usage_calc.py
@@ -0,0 +1,102 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module privides a memory usage calculate function for user.
+The purpose of this API is to allow users to estimate memory usage of
+a program under a special batch size, then user can set appropriate 
+batch size to fully utilize a GPU. 
+
+This API is still under active development and may change drastically.
+"""
+
+from .. import core
+from ..framework import Program, Variable
+
+__all__ = ['memory_usage']
+
+dtype_to_size = {
+    core.VarDesc.VarType.FP16: 2,
+    core.VarDesc.VarType.FP32: 4,
+    core.VarDesc.VarType.FP64: 8,
+    core.VarDesc.VarType.INT16: 2,
+    core.VarDesc.VarType.INT32: 4,
+    core.VarDesc.VarType.INT64: 8,
+    core.VarDesc.VarType.BOOL: 1,
+    core.VarDesc.VarType.UINT8: 1,
+}
+
+DEBUG = False
+
+
+def memory_usage(program, batch_size):
+    """
+    Get the estimate memory usage of program with input batch size.
+
+    Args:
+        program(Program): The current Program.
+        batch_size(int): The current input data batch_size.  
+    
+    Returns:
+        min_total_memory(float): the estimate memory usage lower bound.
+        max_total_memory(float): the estimate memory usage upper bound.
+        unit_str(string): the unit of estimate usage result.
+    
+    Examples:
+        
+        >>> import paddle.fluid as fluid
+        >>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
+                fluid.default_main_program(), batch_size=10)
+        >>> print "memory usage is about %.3f - %.3f %s" % \
+                (lower_usage, upper_usage, unit)
+
+    """
+
+    # Parameters check
+    if not isinstance(program, Program):
+        raise TypeError(
+            "Calculating Memory Usage requires Program as its Parameter."
+            "But you passed in %s" % (type(prgram)))
+    if batch_size <= 0:
+        raise ValueError("The batch size need to be positive.")
+
+    # Get the var_name list of first block and calculate
+    total_memory = 0.0
+    for var in program.global_block().vars.itervalues():
+        data_count = 1
+        for x in var.shape:
+            if x == -1:
+                data_count *= batch_size
+            else:
+                data_count *= x
+        var_memory = data_count * dtype_to_size[var.dtype]
+        if DEBUG:
+            print "%s memory usage: %d" % (var.name, var_memory)
+        total_memory += var_memory
+    if DEBUG:
+        print "total memory usage: %.2f" % (total_memory)
+
+    # Convert appropriate unit
+    unit_str = "B"
+    if total_memory > 1024:
+        total_memory /= 1024
+        unit_str = "KB"
+        if total_memory > 1024:
+            total_memory /= 1024
+            unit_str = "MB"
+
+    # Append extra memory consumption (5% - 10%)
+    min_total_memory = total_memory * 1.05
+    max_total_memory = total_memory * 1.1
+
+    return min_total_memory, max_total_memory, unit_str
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index b24036326d..abd3721268 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
         self.origin_prog = main.clone()
         return main
 
-    def get_trainer(self, config=None):
-        t = self._transpiler_instance(config)
+    def get_trainer(self, config=None, sync_mode=True):
+        t = self._transpiler_instance(config, sync_mode)
         return t.get_trainer_program()
 
-    def get_pserver(self, ep, config=None):
-        t = self._transpiler_instance(config)
+    def get_pserver(self, ep, config=None, sync_mode=True):
+        t = self._transpiler_instance(config, sync_mode)
         pserver = t.get_pserver_program(ep)
         startup = t.get_startup_program(ep, pserver)
         return pserver, startup
 
-    def _transpiler_instance(self, config=None):
+    def _transpiler_instance(self, config=None, sync_mode=True):
         if not self.transpiler:
             main = self.get_main_program()
             self.transpiler = fluid.DistributeTranspiler(config=config)
@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
                 self.trainer_id,
                 program=main,
                 pservers=self.pserver_eps,
-                trainers=self.trainers)
+                trainers=self.trainers,
+                sync_mode=sync_mode)
 
         return self.transpiler
 
@@ -464,5 +465,76 @@ class TestDistLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestAsyncLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["adam", "scale", "scale"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
+            'recv', 'recv', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
+class TestAsyncDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'recv', 'recv'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_memory_usage.py b/python/paddle/fluid/tests/unittests/test_memory_usage.py
new file mode 100644
index 0000000000..f9daf83652
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_memory_usage.py
@@ -0,0 +1,69 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import paddle
+import paddle.fluid as fluid
+import contextlib
+import unittest
+
+
+def train_simulator(test_batch_size=10):
+    if test_batch_size <= 0:
+        raise ValueError("batch_size should be a positive integeral value, "
+                         "but got batch_size={}".format(test_batch_size))
+
+    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+    y_predict = fluid.layers.fc(input=x, size=1, act=None)
+    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+    avg_cost = fluid.layers.mean(cost)
+
+    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    sgd_optimizer.minimize(avg_cost)
+
+    # Calculate memory usage in current network config 
+    lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
+        fluid.default_main_program(), batch_size=test_batch_size)
+
+    print("memory usage is about %.3f - %.3f %s" %
+          (lower_usage, upper_usage, unit))
+
+
+class TestMemoryUsage(unittest.TestCase):
+    def test_with_unit_B(self):
+        with self.program_scope_guard():
+            train_simulator()
+
+    def test_with_unit_KB(self):
+        with self.program_scope_guard():
+            train_simulator(test_batch_size=1000)
+
+    def test_with_unit_MB(self):
+        with self.program_scope_guard():
+            train_simulator(test_batch_size=100000)
+
+    @contextlib.contextmanager
+    def program_scope_guard(self):
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        scope = fluid.core.Scope()
+        with fluid.scope_guard(scope):
+            with fluid.program_guard(prog, startup_prog):
+                yield
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index b0a100e1db..820509bbcc 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -293,14 +293,15 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        program.global_block().append_op(
-            type="fetch_barrier",
-            inputs={},
-            outputs={},
-            attrs={
-                "endpoints": pserver_endpoints,
-                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
-            })
+        if self.sync_mode:
+            program.global_block().append_op(
+                type="fetch_barrier",
+                inputs={},
+                outputs={},
+                attrs={
+                    "endpoints": pserver_endpoints,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
 
         for varname, splited_var in self.param_var_mapping.iteritems():
             if len(splited_var) <= 1:
diff --git a/python/paddle/v2/dataset/conll05.py b/python/paddle/v2/dataset/conll05.py
index 0d544efac9..8312900dc4 100644
--- a/python/paddle/v2/dataset/conll05.py
+++ b/python/paddle/v2/dataset/conll05.py
@@ -29,13 +29,13 @@ __all__ = ['test, get_dict', 'get_embedding', 'convert']
 
 DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
 DATA_MD5 = '387719152ae52d60422c016e92a742fc'
-WORDDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/wordDict.txt'
+WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
 WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
-VERBDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/verbDict.txt'
+VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
 VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
-TRGDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/targetDict.txt'
+TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
 TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
-EMB_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/emb'
+EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
 EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
 
 UNK_IDX = 0
diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py
index 5104e29051..1ec210f265 100644
--- a/python/paddle/v2/dataset/wmt14.py
+++ b/python/paddle/v2/dataset/wmt14.py
@@ -41,7 +41,7 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
              'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
-URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
+URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt%2Fwmt14.tgz'
 MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3'
 
 START = "<s>"