From a53e8a8da6a96e559c0ca38367024f2c5b04c021 Mon Sep 17 00:00:00 2001
From: Brian Liu <brian.liu@intel.com>
Date: Sat, 9 Jun 2018 09:23:14 +0800
Subject: [PATCH 1/5] Update MKLDNN integration framework to support Paddle
 multi-instances

Make all blob info saved in global device context to be thread based.
Meanwhile save thread id in thread local storage in ParallelDo
---
 paddle/fluid/platform/device_context.cc | 65 +++++++++++++++++++------
 paddle/fluid/platform/device_context.h  | 10 +++-
 2 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 7d1cf57253..690ba55279 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -25,6 +25,14 @@ namespace platform {
 
 DeviceContextPool* DeviceContextPool::pool = nullptr;
 
+namespace {
+// Current thread's id.
+thread_local int cur_thread_id = 0;
+}
+
+void set_cur_thread_id(int tid) { cur_thread_id = tid; }
+int get_cur_thread_id(void) { return cur_thread_id; }
+
 platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   auto it = device_contexts_.find(place);
   if (it == device_contexts_.end()) {
@@ -296,38 +304,65 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
 
 #ifdef PADDLE_WITH_MKLDNN
 MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
-    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
-  p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
+    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
+  p_blobmap_.reset(new BlobMap());
+  p_mutex_.reset(new std::mutex());
 }
 
 void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                   std::shared_ptr<void> data) const {
-  std::unordered_map<std::string, std::shared_ptr<void>>* p;
-  p = p_blobs_.get();
+  BlobMap* pMap = p_blobmap_.get();
+  std::shared_ptr<KeyBlob> pBlob = nullptr;
+
+  int tid = platform::get_cur_thread_id();
 
-  auto it = p->find(name);
+  std::lock_guard<std::mutex> lock(*p_mutex_.get());
 
-  if (it == p->end()) {
-    (*p)[name] = data;  // create new blob
+  // Find KeyBlob for current thread
+  auto map_it = pMap->find(tid);
+
+  if (map_it == pMap->end()) {
+    // 1st time to set blob in current thread
+    pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
+    (*pMap)[tid] = pBlob;
   } else {
-    it->second = data;  // set data to existing blob
+    pBlob = map_it->second;
   }
 
+  // Find Key in found (or newly created) KeyBlob
+  auto key_it = pBlob->find(name);
+
+  if (key_it == pBlob->end()) {
+    (*pBlob)[name] = data;  // create new blob
+  } else {
+    key_it->second = data;  // set data to existing blob
+  }
+
+  // lock will be automatically released when out of scope
   return;
 }
 
 std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
     const std::string& name) const {
-  std::unordered_map<std::string, std::shared_ptr<void>>* p;
-  p = p_blobs_.get();
+  BlobMap* pMap = p_blobmap_.get();
+  std::shared_ptr<KeyBlob> pBlob = nullptr;
 
-  auto it = p->find(name);
+  int tid = platform::get_cur_thread_id();
 
-  if (it != p->end()) {
-    return it->second;
-  }
+  std::lock_guard<std::mutex> lock(*p_mutex_.get());
+
+  // Find KeyBlob for current thread firstly
+  auto map_it = pMap->find(tid);
+  if (map_it == pMap->end()) return nullptr;
+  pBlob = map_it->second;
+
+  // Find Blob via name
+  auto key_it = pBlob->find(name);
+
+  if (key_it == pBlob->end()) return nullptr;
 
-  return nullptr;
+  // lock will be automatically released when out of scope
+  return key_it->second;
 }
 
 #endif
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 999bbe00f1..1527c9f324 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -39,6 +39,12 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
+using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
+
+void set_cur_thread_id(int);
+int get_cur_thread_id(void);
+
 class DeviceContext {
  public:
   virtual ~DeviceContext() {}
@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
 
  private:
   mkldnn::engine engine_;
-  std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
-      p_blobs_;
+  std::shared_ptr<BlobMap> p_blobmap_;
+  std::shared_ptr<std::mutex> p_mutex_;
 };
 #endif
 

From 741cb33bd97dcb121d866acf18458f95527f3a11 Mon Sep 17 00:00:00 2001
From: Sylwester Fraczek <sylwester.fraczek@intel.com>
Date: Tue, 16 Oct 2018 14:52:45 +0200
Subject: [PATCH 2/5] test multithreading

---
 paddle/fluid/inference/api/helper.h              | 3 ++-
 paddle/fluid/inference/tests/api/tester_helper.h | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h
index 24f59cf43a..e46dc13269 100644
--- a/paddle/fluid/inference/api/helper.h
+++ b/paddle/fluid/inference/api/helper.h
@@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
                       double latency, int epoch = 1) {
   LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
             << ", threads: " << num_threads << ", thread id: " << tid
-            << ", latency: " << latency << "ms ======";
+            << ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f)
+            << " ======";
   if (epoch > 1) {
     int samples = batch_size * epoch;
     LOG(INFO) << "====== sample number: " << samples
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 5589b58b06..42072895fc 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -139,6 +139,7 @@ void TestMultiThreadPrediction(
   }
   for (int tid = 0; tid < num_threads; ++tid) {
     threads.emplace_back([&, tid]() {
+      platform::set_cur_thread_id(static_cast<int>(tid) + 1);
       // Each thread should have local inputs and outputs.
       // The inputs of each thread are all the same.
       std::vector<std::vector<PaddleTensor>> inputs_tid = inputs;

From bba0c4a9f2d8ea8936595e438cc6abca0e0f710b Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Fri, 26 Oct 2018 15:21:23 +0800
Subject: [PATCH 3/5] delete unused codes.

test=develop
---
 paddle/fluid/framework/ir/graph.cc | 62 ------------------------------
 paddle/fluid/framework/ir/node.h   |  2 +
 paddle/fluid/framework/op_desc.h   |  4 --
 3 files changed, 2 insertions(+), 66 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 398f709596..11102bc776 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -24,68 +24,6 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-std::vector<std::string> FindDistTrainSendVars(
-    const std::vector<ir::Node *> &nodes) {
-  std::vector<std::string> send_vars;
-  // since parameters are all in block 0,
-  // it's enough to only scan send ops in block 0
-  for (auto &node : nodes) {
-    auto op_vars = node->Op()->InputArgumentNames();
-    send_vars.reserve(send_vars.size() +
-                      std::distance(op_vars.begin(), op_vars.end()));
-    send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
-  }
-  return send_vars;
-}
-
-std::vector<std::string> FindDistTrainRecvVars(
-    const std::vector<ir::Node *> &nodes) {
-  std::vector<std::string> recv_vars;
-  for (auto &node : nodes) {
-    auto op_vars = node->Op()->OutputArgumentNames();
-    recv_vars.reserve(recv_vars.size() +
-                      std::distance(op_vars.begin(), op_vars.end()));
-    recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
-  }
-  return recv_vars;
-}
-
-bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
-                   const std::vector<std::string> &recv_vars) {
-  if (send_vars.size() == 0 || recv_vars.size() == 0) {
-    return false;
-  }
-
-  /**
-   * Check any of opvars contains `.block` and in sendvars
-   */
-  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &rpc_vars) -> bool {
-    for (auto &var : opvars) {
-      // a variable name with the suffix `.block` means it's a splited
-      // variable by (DistributeTranspiler)
-      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
-      if (var.find(".block") != std::string::npos &&
-          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  std::vector<std::string> input_var_names;
-  std::vector<std::string> output_var_names;
-  for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Name());
-  }
-  for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Name());
-  }
-
-  return checker(output_var_names, send_vars) ||
-         checker(input_var_names, recv_vars);
-}
-
 Graph::Graph(const ProgramDesc &program) : program_(program) {
   // Make the nodes id start from 0.
   Node::ResetId();
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 5d6da9f1d7..d6d42f5e92 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -44,6 +44,7 @@ class Node {
     return op_desc_.get();
   }
 
+  // Please don't use this API!
   int id() const { return id_; }
 
   bool IsOp() const { return type_ == Type::kOperation; }
@@ -92,6 +93,7 @@ class Node {
   Node() = delete;
 
   static int count_;
+  // Please don't use this API or make this public.
   static void ResetId() { count_ = 0; }
   DISABLE_COPY_AND_ASSIGN(Node);
 };
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 440e0509be..30c8a26c3d 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -121,10 +121,6 @@ class OpDesc {
 
   BlockDesc *Block() { return this->block_; }
 
-  const BlockDesc &BlockRef() const { return *this->block_; }
-
-  void SetBlock(BlockDesc *block) { this->block_ = block; }
-
  private:
   template <typename MapType>
   static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {

From 2098b42584f0d6c588d2ec62f6b37a4dc8916e68 Mon Sep 17 00:00:00 2001
From: Sylwester Fraczek <sylwester.fraczek@intel.com>
Date: Wed, 24 Oct 2018 10:26:07 +0200
Subject: [PATCH 4/5] review fixes (Teamcity fails)

test=develop
---
 paddle/fluid/inference/tests/api/tester_helper.h |  2 ++
 paddle/fluid/platform/device_context.cc          | 16 ++++++++--------
 paddle/fluid/platform/device_context.h           | 12 ++++++------
 3 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 42072895fc..19c3f532d5 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -139,7 +139,9 @@ void TestMultiThreadPrediction(
   }
   for (int tid = 0; tid < num_threads; ++tid) {
     threads.emplace_back([&, tid]() {
+#ifdef PADDLE_WITH_MKLDNN
       platform::set_cur_thread_id(static_cast<int>(tid) + 1);
+#endif
       // Each thread should have local inputs and outputs.
       // The inputs of each thread are all the same.
       std::vector<std::vector<PaddleTensor>> inputs_tid = inputs;
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 690ba55279..b0de636de4 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -25,14 +25,6 @@ namespace platform {
 
 DeviceContextPool* DeviceContextPool::pool = nullptr;
 
-namespace {
-// Current thread's id.
-thread_local int cur_thread_id = 0;
-}
-
-void set_cur_thread_id(int tid) { cur_thread_id = tid; }
-int get_cur_thread_id(void) { return cur_thread_id; }
-
 platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   auto it = device_contexts_.find(place);
   if (it == device_contexts_.end()) {
@@ -309,6 +301,14 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
   p_mutex_.reset(new std::mutex());
 }
 
+namespace {
+// Current thread's id.
+thread_local int cur_thread_id = 0;
+}
+
+void set_cur_thread_id(int tid) { cur_thread_id = tid; }
+int get_cur_thread_id(void) { return cur_thread_id; }
+
 void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                   std::shared_ptr<void> data) const {
   BlobMap* pMap = p_blobmap_.get();
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 1527c9f324..942e13a724 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -39,12 +39,6 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
-using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
-using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
-
-void set_cur_thread_id(int);
-int get_cur_thread_id(void);
-
 class DeviceContext {
  public:
   virtual ~DeviceContext() {}
@@ -182,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
+using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
+using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
+
+void set_cur_thread_id(int);
+int get_cur_thread_id(void);
+
 class MKLDNNDeviceContext : public CPUDeviceContext {
  public:
   explicit MKLDNNDeviceContext(CPUPlace place);

From 26200f2e420566cba3112ee725197a1c12c8682b Mon Sep 17 00:00:00 2001
From: Wu Yi <typhoonzero1986@gmail.com>
Date: Mon, 29 Oct 2018 14:10:08 +0800
Subject: [PATCH 5/5] [1.1] [project] train imagenet using large batch size
 (#13766)

* fix nccl2 lars dist support

* put lars in momentum op

* add tests lars

* fix ci

* fix cpu kernel

* soft warning

* remove lars in test_recognize_digits.py

* move to another op

* add file

* update api.spec test=develop

* update test=develop

* fix api.spec test=develop

* wip

* wip, finish grad merge ops

* wip, finish graph build

* wip test running

* work on 1 gpu

* workable version

* update

* fix tests

* fuse broadcast op

* fix compile failed

* refine

* add batch merge test mnist

* fix CI test=develop

* fix build

* use independent bn params for batch merge test=develop

* update api.spec

* follow comments and for test

* wip

* refine tests test=develop

* follow comments test=develop

* remove startup bn modify test=develop

* follow comments test=develop

* fix merge test=develop
---
 benchmark/fluid/args.py                       |   5 +
 benchmark/fluid/fluid_benchmark.py            |   2 +-
 paddle/fluid/API.spec                         |   2 +
 paddle/fluid/framework/details/CMakeLists.txt |   6 +-
 .../framework/details/broadcast_op_handle.cc  |  21 +-
 .../framework/details/broadcast_op_handle.h   |   5 +-
 .../fluid/framework/details/build_strategy.cc |   1 +
 .../fluid/framework/details/build_strategy.h  |   2 +
 .../details/fused_broadcast_op_handle.cc      |  55 +++
 .../details/fused_broadcast_op_handle.h       |  57 ++++
 .../details/multi_devices_graph_pass.cc       |  62 +++-
 .../details/multi_devices_graph_pass.h        |   7 +-
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 paddle/fluid/framework/ir/graph.cc            |  13 +-
 paddle/fluid/framework/ir/graph.h             |   6 +
 .../framework/ir/multi_batch_merge_pass.cc    | 315 ++++++++++++++++++
 .../framework/ir/multi_batch_merge_pass.h     |  44 +++
 paddle/fluid/framework/parallel_executor.cc   |  24 +-
 paddle/fluid/operators/lars_momentum_op.cc    |  86 +++++
 paddle/fluid/operators/lars_momentum_op.cu    |  94 ++++++
 paddle/fluid/operators/lars_momentum_op.h     |  72 ++++
 paddle/fluid/operators/momentum_op.cc         |  48 ---
 paddle/fluid/operators/momentum_op.h          |  48 +++
 paddle/fluid/pybind/pybind.cc                 |  10 +-
 .../fluid/layers/learning_rate_scheduler.py   |  26 +-
 python/paddle/fluid/optimizer.py              |  91 ++++-
 .../fluid/tests/unittests/dist_mnist.py       |   2 +-
 .../tests/unittests/dist_mnist_batch_merge.py |  80 +++++
 .../fluid/tests/unittests/dist_mnist_lars.py  |  73 ++++
 .../fluid/tests/unittests/test_dist_base.py   |  27 +-
 .../fluid/tests/unittests/test_dist_mnist.py  |   9 +
 .../unittests/test_dist_mnist_batch_merge.py  |  67 ++++
 .../fluid/tests/unittests/test_momentum_op.py |  39 +++
 .../fluid/transpiler/distribute_transpiler.py |   6 +-
 34 files changed, 1300 insertions(+), 106 deletions(-)
 create mode 100644 paddle/fluid/framework/details/fused_broadcast_op_handle.cc
 create mode 100644 paddle/fluid/framework/details/fused_broadcast_op_handle.h
 create mode 100644 paddle/fluid/framework/ir/multi_batch_merge_pass.cc
 create mode 100644 paddle/fluid/framework/ir/multi_batch_merge_pass.h
 create mode 100644 paddle/fluid/operators/lars_momentum_op.cc
 create mode 100644 paddle/fluid/operators/lars_momentum_op.cu
 create mode 100644 paddle/fluid/operators/lars_momentum_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py
 create mode 100644 python/paddle/fluid/tests/unittests/dist_mnist_lars.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py

diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py
index 9540900b11..ff616ddbb2 100644
--- a/benchmark/fluid/args.py
+++ b/benchmark/fluid/args.py
@@ -142,5 +142,10 @@ def parse_args():
         choices=['reduce', 'all_reduce'],
         default='all_reduce',
         help='Specify the reduce strategy, can be reduce, all_reduce')
+    parser.add_argument(
+        '--fuse_broadcast_op',
+        action='store_true',
+        help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.'
+    )
     args = parser.parse_args()
     return args
diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py
index ddd9fe8098..5f3ce300ac 100644
--- a/benchmark/fluid/fluid_benchmark.py
+++ b/benchmark/fluid/fluid_benchmark.py
@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
     else:
         build_strategy.reduce_strategy = fluid.BuildStrategy(
         ).ReduceStrategy.AllReduce
+    build_strategy.fuse_broadcast_op = args.fuse_broadcast_op
 
     avg_loss = train_args[0]
 
@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
 
             if args.use_fake_data or args.use_reader_op:
                 try:
-
                     fetch_ret = exe.run(fetch_list)
                 except fluid.core.EOFException as eof:
                     break
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 0d90bf3cc1..2b8b82e74f 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -355,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind
 paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None))
+paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
 paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index e0a3ef5a9c..17188ac5f3 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -16,12 +16,14 @@ if(WITH_GPU)
             dynload_cuda variable_visitor)
     nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
     nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
+    nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 
 else()
     cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
              variable_visitor)
     cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
     cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
+    cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 endif()
 
 cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
@@ -34,7 +36,7 @@ if(WITH_GPU)
 endif()
 
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
-        scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
+        scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
 
 if(WITH_GPU)
   cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
@@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
 cc_library(build_strategy SRCS build_strategy.cc DEPS
         graph_viz_pass multi_devices_graph_pass
         multi_devices_graph_print_pass multi_devices_graph_check_pass
-        fuse_elewise_add_act_pass)
+        fuse_elewise_add_act_pass multi_batch_merge_pass)
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 4fdab5cd94..5b5a10e227 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -48,16 +48,23 @@ void BroadcastOpHandle::RunImpl() {
     var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
   }
 
+  BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
+}
+
+void BroadcastOpHandle::BroadcastOneVar(
+    const VarHandle &in_var_handle,
+    const std::vector<VarHandle *> &out_var_handles,
+    const std::vector<const Scope *> &var_scopes) {
   auto *in_var =
-      var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
+      var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  InitOutputValue(*in_var_handle, out_var_handles);
+  InitOutputValue(in_var_handle, out_var_handles);
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
-      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
+      if (out_var_handle->IsTheSameVar(in_var_handle)) {
         continue;
       }
       auto &out_p = out_var_handle->place_;
@@ -114,12 +121,12 @@ void BroadcastOpHandle::RunImpl() {
         }
       }
 
-      if (!out_handle->IsTheSameVar(*in_var_handle)) {
-        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+      if (!out_handle->IsTheSameVar(in_var_handle)) {
+        auto out_var = var_scopes.at(in_var_handle.scope_idx_)
                            ->FindVar(out_var_handles[0]->name_);
         paddle::framework::TensorCopy(
-            in_tensor, in_var_handle->place_,
-            *(dev_ctxes_.at(in_var_handle->place_)),
+            in_tensor, in_var_handle.place_,
+            *(dev_ctxes_.at(in_var_handle.place_)),
             &VariableVisitor::GetMutableTensor(out_var));
       }
     });
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h
index fe4e733e43..020d351e89 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.h
+++ b/paddle/fluid/framework/details/broadcast_op_handle.h
@@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
- private:
+  void BroadcastOneVar(const VarHandle &in_var_handle,
+                       const std::vector<VarHandle *> &out_var_handles,
+                       const std::vector<const Scope *> &var_scopes);
+
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
 #ifdef PADDLE_WITH_CUDA
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 6a6b497fa8..fefd27fc86 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -121,6 +121,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
 
 USE_PASS(fuse_elewise_add_act_pass);
 USE_PASS(graph_viz_pass);
+USE_PASS(multi_batch_merge_pass);
 USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 02c4bea169..f3ffaf6ecd 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -69,6 +69,8 @@ struct BuildStrategy {
 
   bool enable_data_balance_{false};
 
+  bool fuse_broadcast_op_{false};
+
   // User normally doesn't need to call this API.
   // The PassBuilder allows for more customized insert, remove of passes
   // from python side.
diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc
new file mode 100644
index 0000000000..51dfa2d071
--- /dev/null
+++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc
@@ -0,0 +1,55 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+void FusedBroadcastOpHandle::RunImpl() {
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+
+  if (places_.size() == 1UL) return;
+
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
+  WaitInputVarGenerated();
+
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+
+  size_t place_num = places_.size();
+  PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
+
+  for (size_t i = 0; i < in_var_handles.size(); ++i) {
+    BroadcastOneVar(
+        *in_var_handles[i],
+        std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
+                                 out_var_handles.begin() + (i + 1) * place_num),
+        var_scopes);
+  }
+}
+
+std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; }
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.h b/paddle/fluid/framework/details/fused_broadcast_op_handle.h
new file mode 100644
index 0000000000..e37259526a
--- /dev/null
+++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.h
@@ -0,0 +1,57 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/platform/device_context.h"
+
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct FusedBroadcastOpHandle : public BroadcastOpHandle {
+ public:
+#ifdef PADDLE_WITH_CUDA
+  FusedBroadcastOpHandle(ir::Node *node,
+                         const std::vector<Scope *> local_scopes,
+                         const std::vector<platform::Place> &places,
+                         const platform::NCCLContextMap *nccl_ctx)
+      : BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {}
+#else
+  FusedBroadcastOpHandle(ir::Node* node, const std::vector<Scope*> local_scopes,
+                         const std::vector<platform::Place>& places)
+      : BroadcastOpHandle(node, local_scopes, places) {}
+#endif
+  std::string Name() const override;
+
+ protected:
+  void RunImpl() override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index ebd1d644bc..f2d5b182e5 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/data_balance_op_handle.h"
+#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
@@ -347,7 +348,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           BuildStrategy::GradientScaleStrategy::kCustomized) {
         // TODO(paddle-dev): Why is there no input for this op_handle?
         auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
-        CreateScaleLossGradOp(&result, loss_grad_name);
+        CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
       }
       // This assumes the backward generating code will ensure IsScaleLossOp
       // is true only for the op that scale the final scalar loss.
@@ -436,10 +437,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   if ((use_gpu &&
        strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
       is_dist_train) {
-    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
-      auto &to_bcast_set = bcast_var_name_set[dev_id];
-      for (auto &bcast_name : to_bcast_set) {
-        CreateBroadcastOp(&result, bcast_name, dev_id);
+    if (strategy_.fuse_broadcast_op_) {
+      CreateFusedBroadcastOp(&result, bcast_var_name_set);
+    } else {
+      for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
+        auto &to_bcast_set = bcast_var_name_set[dev_id];
+        for (auto &bcast_name : to_bcast_set) {
+          CreateBroadcastOp(&result, bcast_name, dev_id);
+        }
       }
     }
   }
@@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
   }
 }
 
+void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
+    ir::Graph *result,
+    const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
+#ifdef PADDLE_WITH_CUDA
+  auto *op_handle = new FusedBroadcastOpHandle(
+      result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_);
+#else
+  auto *op_handle = new FusedBroadcastOpHandle(
+      result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_);
+#endif
+  result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
+
+  for (size_t i = 0; i < places_.size(); ++i) {
+    auto &p = places_[i];
+    SetCommunicationContext(op_handle, p);
+  }
+
+  for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
+    for (auto &p_name : bcast_varnames[dev_id]) {
+      auto *in =
+          result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
+      op_handle->AddInput(in);
+      for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
+        auto &p = places_[out_dev_id];
+        auto &vars =
+            result->Get<GraphVars>(kGraphVars).at(out_dev_id).at(p_name);
+        auto *out_var = new VarHandle(
+            result->CreateEmptyNode(p_name, ir::Node::Type::kVariable),
+            vars.size(), out_dev_id, p_name, p);
+        vars.emplace_back(out_var);
+        op_handle->AddOutput(out_var);
+      }
+    }
+  }
+}
+
 void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                     ir::Node *node,
                                                     int dev_id) const {
@@ -602,7 +645,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
 }
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
-    ir::Graph *result, const std::string &loss_grad_name) const {
+    ir::Graph *result, const std::string &loss_grad_name,
+    ir::Node *out_var_node) const {
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
     auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
@@ -617,10 +661,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
     // loss->pending_ops_.emplace_back(op_handle);
     // op_handle->inputs_.emplace_back(loss);
 
-    CreateOpOutput(
-        result, op_handle,
-        result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
-        places_[i], i);
+    CreateOpOutput(result, op_handle,
+                   result->CreateVarNode(out_var_node->Var()), places_[i], i);
   }
 }
 
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index cdf9f13cde..03b2de2f04 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -61,7 +61,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
                               size_t num_places) const;
 
   void CreateScaleLossGradOp(ir::Graph *result,
-                             const std::string &loss_grad_name) const;
+                             const std::string &loss_grad_name,
+                             ir::Node *out_var_node) const;
 
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                             int dst_dev_id) const;
@@ -78,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
+  void CreateFusedBroadcastOp(
+      ir::Graph *result,
+      const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
+
   bool IsSparseGradient(const std::string &og) const;
 
   size_t GetAppropriateDeviceID(
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index a145b2fafe..ce006b7a3f 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -36,6 +36,7 @@ pass_library(fc_lstm_fuse_pass inference)
 pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
+pass_library(multi_batch_merge_pass base)
 pass_library(conv_bn_fuse_pass inference)
 pass_library(seqconv_eltadd_relu_fuse_pass inference)
 if(WITH_MKLDNN)
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 11102bc776..265a128e95 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -27,14 +27,20 @@ namespace ir {
 Graph::Graph(const ProgramDesc &program) : program_(program) {
   // Make the nodes id start from 0.
   Node::ResetId();
+  auto var_nodes = InitFromProgram(program_);
+  ResolveHazard(var_nodes);
+}
 
+std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
+    const ProgramDesc &program) {
   VLOG(3) << "block in program:" << program_.Size();
   std::unordered_map<std::string, VarDesc *> all_vars;
+  // var nodes for each var name, will have multiple versions in SSA
+  std::map<std::string, std::vector<ir::Node *>> var_nodes;
   for (auto *var : program.Block(0).AllVars()) {
     all_vars.emplace(var->Name(), var);
   }
 
-  std::map<std::string, std::vector<ir::Node *>> var_nodes;
   for (auto *op : program.Block(0).AllOps()) {
     ir::Node *node = CreateOpNode(op);
     // For input args, reuse the same var name if it was created before.
@@ -72,7 +78,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
       var->inputs.push_back(node);
     }
   }
+  return std::move(var_nodes);
+}
 
+void Graph::ResolveHazard(
+    const std::map<std::string, std::vector<ir::Node *>> &var_nodes) {
   /**
    * We should handle write after read(WAR) and write after write(WAW) here.
    * Because some of the operators of the program can be executed parallelly.
@@ -91,6 +101,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
     auto it_old = versions.rbegin();
     ++it_old;
     for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
+      VLOG(3) << "deal with var: " << (*it_new)->Name();
       ir::Node *write_op =
           (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
       const auto &read_ops = (*it_old)->outputs;
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index ab687e760a..9d7aa5d32d 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -160,6 +160,12 @@ class Graph {
     return nullptr;
   }
 
+  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
+      const ProgramDesc &program);
+
+  void ResolveHazard(
+      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
+
  private:
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc
new file mode 100644
index 0000000000..bd5b76426e
--- /dev/null
+++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc
@@ -0,0 +1,315 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/multi_batch_merge_pass.h"
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+static const char kNumRepeats[] = "num_repeats";
+typedef std::unordered_map<std::string, std::vector<ir::Node*>> SSAVarList;
+
+ir::Node* SameNameVar(std::unordered_set<ir::Node*> all, ir::Node* target) {
+  for (auto n : all) {
+    if (target->IsVar() && target->Name() == n->Name()) {
+      return n;
+    }
+  }
+  return nullptr;
+}
+
+VarDesc CopyVarDesc(VarDesc* var_desc) {
+  VarDesc repeated_var(var_desc->Name());
+  // copy other variable attributes
+  if (var_desc->GetType() != proto::VarType::READER) {
+    repeated_var.SetType(var_desc->GetType());
+    repeated_var.SetShape(var_desc->GetShape());
+    repeated_var.SetDataType(var_desc->GetDataType());
+    repeated_var.SetLoDLevel(var_desc->GetLoDLevel());
+    repeated_var.SetPersistable(var_desc->Persistable());
+  } else {
+    // TODO(typhoonzero): copy reader var
+  }
+  return repeated_var;
+}
+
+VarDesc UpdateGradVarDesc(
+    VarDesc* var_desc, int repeat,
+    const std::unordered_set<std::string>& grad_names,
+    const std::unordered_set<std::string>& bn_vars_need_rename) {
+  if (grad_names.find(var_desc->Name()) != grad_names.end() ||
+      bn_vars_need_rename.find(var_desc->Name()) != bn_vars_need_rename.end()) {
+    std::string new_gname =
+        string::Sprintf("%s.repeat.%d", var_desc->Name(), repeat);
+    VarDesc repeated_var = CopyVarDesc(var_desc);
+    repeated_var.SetName(new_gname);
+    VLOG(3) << "update " << var_desc->Name() << " to repeat " << repeat;
+    return repeated_var;
+  }
+  return *var_desc;
+}
+
+std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
+    std::unique_ptr<Graph> graph) const {
+  int num_repeats = Get<const int>(kNumRepeats);
+  std::vector<Node*> forward_backward_ops;
+  std::vector<Node*> optimize_ops;
+  std::vector<Node*> lr_ops;  // ops other than forward/backward/optimize
+  std::unordered_set<std::string> grad_names;
+
+  std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
+  auto origin_nodes = graph->ReleaseNodes();
+  VLOG(3) << "origin nodes count: " << origin_nodes.size();
+  ir::Graph& result = *graph;
+
+  // 1. record op nodes of different roles
+  for (auto node : nodes) {
+    if (node->IsVar()) continue;
+    int op_role = boost::get<int>(node->Op()->GetAttr(
+        framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
+    if ((op_role == static_cast<int>(framework::OpRole::kForward)) ||
+        (op_role & static_cast<int>(framework::OpRole::kBackward)) ||
+        (op_role & static_cast<int>(framework::OpRole::kLoss))) {
+      forward_backward_ops.push_back(node);
+    } else if ((op_role & static_cast<int>(framework::OpRole::kOptimize)) ||
+               (op_role & static_cast<int>(framework::OpRole::kDist)) ||
+               (op_role & static_cast<int>(framework::OpRole::kRPC))) {
+      optimize_ops.push_back(node);
+      auto op_role_var = node->Op()->GetNullableAttr(
+          OpProtoAndCheckerMaker::OpRoleVarAttrName());
+      auto op_role_vars = boost::get<std::vector<std::string>>(op_role_var);
+      for (size_t i = 0; i < op_role_vars.size(); i += 2) {
+        grad_names.insert(op_role_vars[i + 1]);
+      }
+    } else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) {
+      lr_ops.push_back(node);
+    } else {  // NOLINT
+      PADDLE_THROW("Invalid op_role: %d", static_cast<int>(op_role));
+    }
+  }
+
+  // 2. copy forward backward
+  ir::Node* prev_repeat_last_op_node = nullptr;
+  // record origin_grad -> repeated grad list map.
+  std::map<ir::Node*, std::vector<ir::Node*>> grad_repeated_map;
+  std::map<std::string, std::vector<ir::Node*>> created;
+  std::unordered_set<std::string> bn_vars_need_rename;
+  for (int i = 0; i < num_repeats; ++i) {
+    std::unordered_set<ir::Node*> copied;
+    for (size_t node_idx = 0; node_idx < forward_backward_ops.size();
+         ++node_idx) {
+      auto node = forward_backward_ops[node_idx];
+      OpDesc repeated_op(*(node->Op()), node->Op()->Block());
+      // 3. rename grad outputs to current repeat.
+      for (auto outname : repeated_op.OutputArgumentNames()) {
+        if (grad_names.find(outname) != grad_names.end()) {
+          std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i);
+          repeated_op.RenameOutput(outname, new_gname);
+        }
+      }
+      // 3.5 let batch_norm ops use independent vars, note batch_norm_grad do
+      // not need this update
+      if (node->Name() == "batch_norm") {
+        // NOTE: assume bn op created by layers use save var as output mean and
+        // variance
+        std::string new_mean_name =
+            string::Sprintf("%s.repeat.%d", repeated_op.Input("Mean")[0], i);
+        std::string new_var_name = string::Sprintf(
+            "%s.repeat.%d", repeated_op.Input("Variance")[0], i);
+        bn_vars_need_rename.insert(repeated_op.Input("Mean")[0]);
+        bn_vars_need_rename.insert(repeated_op.Input("Variance")[0]);
+        VLOG(3) << "renaming " << repeated_op.Input("Mean")[0] << " to "
+                << new_mean_name;
+        repeated_op.RenameInput(repeated_op.Input("Mean")[0], new_mean_name);
+        repeated_op.RenameInput(repeated_op.Input("Variance")[0], new_var_name);
+        repeated_op.RenameOutput(repeated_op.Output("MeanOut")[0],
+                                 new_mean_name);
+        repeated_op.RenameOutput(repeated_op.Output("VarianceOut")[0],
+                                 new_var_name);
+      }
+
+      // 3.9 do copy
+      auto repeated_node = result.CreateOpNode(&repeated_op);
+      copied.insert(node);
+
+      // 4. add deps between repeats
+      if (node_idx == forward_backward_ops.size() - 1) {
+        prev_repeat_last_op_node = repeated_node;
+      }
+      if (node_idx == 0 && prev_repeat_last_op_node) {
+        auto* depvar = result.CreateControlDepVar();
+        prev_repeat_last_op_node->outputs.push_back(depvar);
+        depvar->inputs.push_back(prev_repeat_last_op_node);
+        repeated_node->inputs.push_back(depvar);
+        depvar->outputs.push_back(repeated_node);
+      }
+
+      for (auto in_node : node->inputs) {
+        if (in_node->IsCtrlVar()) {
+          continue;
+        }
+        ir::Node* var = nullptr;
+        auto updated_var = UpdateGradVarDesc(in_node->Var(), i, grad_names,
+                                             bn_vars_need_rename);
+        // should be initialized by startup, how to initilize tensor in the
+        // scope?
+        if (node->Name() == "batch_norm" &&
+            bn_vars_need_rename.find(in_node->Name()) !=
+                bn_vars_need_rename.end()) {
+          // Create bn mean/variance for each repeat
+          var = result.CreateVarNode(&updated_var);
+          created[updated_var.Name()].push_back(var);
+          copied.insert(in_node);
+          repeated_node->inputs.push_back(var);
+          var->outputs.push_back(repeated_node);
+          continue;
+        }
+
+        // for other ops
+        if (in_node->inputs.empty() && i > 0) {
+          // do not copy head vars (inputs, params) in repeats > 0
+          var = created.at(in_node->Name()).back();
+        } else {
+          if (copied.find(in_node) == copied.end()) {
+            var = result.CreateVarNode(&updated_var);
+            if (grad_names.find(in_node->Var()->Name()) != grad_names.end()) {
+              grad_repeated_map[in_node].push_back(var);
+            }
+            copied.insert(in_node);
+            created[updated_var.Name()].push_back(var);
+          } else {
+            var = created.at(updated_var.Name()).back();
+          }
+        }
+        repeated_node->inputs.push_back(var);
+        var->outputs.push_back(repeated_node);
+      }
+      for (auto out_node : node->outputs) {
+        if (out_node->IsCtrlVar()) {
+          continue;
+        }
+        ir::Node* var = nullptr;
+        auto updated_var = UpdateGradVarDesc(out_node->Var(), i, grad_names,
+                                             bn_vars_need_rename);
+        if (copied.find(out_node) == copied.end()) {
+          var = result.CreateVarNode(&updated_var);
+          if (grad_names.find(out_node->Var()->Name()) != grad_names.end()) {
+            grad_repeated_map[out_node].push_back(var);
+          }
+          copied.insert(out_node);
+          created[updated_var.Name()].push_back(var);
+        } else {
+          var = created.at(updated_var.Name()).back();
+        }
+        repeated_node->outputs.push_back(var);
+        var->inputs.push_back(repeated_node);
+      }
+    }
+  }
+
+  // 5. create GRAD merge op node
+  for (auto kv : grad_repeated_map) {
+    OpDesc sum_op;
+    sum_op.SetType("sum");
+    std::vector<std::string> repeated_grad_names;
+    for (auto r : kv.second) {
+      repeated_grad_names.push_back(r->Var()->Name());
+    }
+    sum_op.SetInput("X", repeated_grad_names);
+    sum_op.SetOutput("Out", {kv.first->Var()->Name()});
+    sum_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                   static_cast<int>(OpRole::kBackward));
+    auto sum_op_node = result.CreateOpNode(&sum_op);
+    for (auto r : kv.second) {
+      sum_op_node->inputs.push_back(r);
+      r->outputs.push_back(sum_op_node);
+    }
+    auto sum_out_var_node = result.CreateVarNode(kv.first->Var());
+    sum_op_node->outputs.push_back(sum_out_var_node);
+    sum_out_var_node->inputs.push_back(sum_op_node);
+    created[sum_out_var_node->Name()].push_back(sum_out_var_node);
+
+    OpDesc scale_op;
+    scale_op.SetType("scale");
+    scale_op.SetInput("X", {sum_out_var_node->Var()->Name()});
+    // NOTE: inplace scale.
+    scale_op.SetOutput("Out", {sum_out_var_node->Var()->Name()});
+    scale_op.SetAttr("scale", static_cast<float>(1.0f / num_repeats));
+    scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                     static_cast<int>(OpRole::kBackward));
+    auto scale_op_node = result.CreateOpNode(&scale_op);
+    scale_op_node->inputs.push_back(sum_out_var_node);
+    sum_out_var_node->outputs.push_back(scale_op_node);
+    auto scale_out_var_node = result.CreateVarNode(sum_out_var_node->Var());
+    scale_op_node->outputs.push_back(scale_out_var_node);
+    scale_out_var_node->inputs.push_back(scale_op_node);
+    created[scale_out_var_node->Name()].push_back(scale_out_var_node);
+  }
+  // 6. add optimize ops
+  {
+    auto copy_node = [&result, &created](ir::Node* node) {
+      auto op_node = result.CreateOpNode(node->Op());
+      // copy op ins/outs
+      // NOTE: for send/recv ops, the OpDesc uses ctrldepvar to describe
+      // dependencies, so create those depvars if OpDesc have in/outs.
+      for (auto in_node : node->inputs) {
+        if (in_node->IsCtrlVar() && !in_node->Var()) {
+          continue;
+        }
+        ir::Node* var = nullptr;
+        if (created.find(in_node->Name()) == created.end()) {
+          var = result.CreateVarNode(in_node->Var());
+          created[in_node->Name()].push_back(var);
+        } else {
+          var = created.at(in_node->Name()).back();
+        }
+        op_node->inputs.push_back(var);
+        var->outputs.push_back(op_node);
+      }
+      for (auto out_node : node->outputs) {
+        if (out_node->IsCtrlVar() && !out_node->Var()) {
+          continue;
+        }
+        auto var = result.CreateVarNode(out_node->Var());
+        created[out_node->Name()].push_back(var);
+        op_node->outputs.push_back(var);
+        var->inputs.push_back(op_node);
+      }
+    };
+    for (auto node : lr_ops) {
+      copy_node(node);
+    }
+    for (auto node : optimize_ops) {
+      copy_node(node);
+    }
+  }
+
+  result.ResolveHazard(created);
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(multi_batch_merge_pass, paddle::framework::ir::BatchMergePass)
+    .RequirePassAttr(paddle::framework::ir::kNumRepeats);
diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.h b/paddle/fluid/framework/ir/multi_batch_merge_pass.h
new file mode 100644
index 0000000000..c1e5aef20d
--- /dev/null
+++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// BatchMergePass is used to copy forward and backward ops for several
+// times to run several batches to simulate large batch size training
+// as if we have more than 1 GPUs.
+// User can define how many batches to run, gradients will be merged
+// through those repeats, and then do optimization using merged gradients.
+// This pass is extremely useful when doing large batch-size distributed
+// sync training, we can simulate even large batch size as if we have more
+// GPUs.
+
+class BatchMergePass : public Pass {
+ public:
+  virtual ~BatchMergePass() {}
+
+ protected:
+  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 3368ae2ee4..cffb96bedf 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -109,18 +109,9 @@ ParallelExecutor::ParallelExecutor(
   if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
     BCastParamsToDevices(bcast_vars);
   }
-  // Startup Program has been run. All local scopes has correct parameters.
+// Startup Program has been run. All local scopes has correct parameters.
 
-  // Step 2. Create vars in each scope;
-  std::vector<details::VariableInfo> var_infos;
-  for (auto *var : main_program.Block(0).AllVars()) {
-    var_infos.emplace_back();
-    var_infos.back().name_ = var->Name();
-    var_infos.back().type_ = var->GetType();
-    var_infos.back().persistable_ = var->Persistable();
-  }
-
-// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
+// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
 // ncclOp
 #ifdef PADDLE_WITH_CUDA
   std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
@@ -156,6 +147,17 @@ ParallelExecutor::ParallelExecutor(
                            params, member_->local_scopes_, member_->use_cuda_);
 #endif
 
+  // Step 3. Create vars in each scope. Passes may also create new vars.
+  //         skip control vars and empty vars
+  std::vector<details::VariableInfo> var_infos;
+  for (auto &node : graph->Nodes()) {
+    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
+      var_infos.emplace_back();
+      var_infos.back().name_ = node->Var()->Name();
+      var_infos.back().type_ = node->Var()->GetType();
+      var_infos.back().persistable_ = node->Var()->Persistable();
+    }
+  }
   // If the loss_var_name is given, the number of graph should be only one.
   if (loss_var_name.size()) {
     PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
diff --git a/paddle/fluid/operators/lars_momentum_op.cc b/paddle/fluid/operators/lars_momentum_op.cc
new file mode 100644
index 0000000000..a8dda93902
--- /dev/null
+++ b/paddle/fluid/operators/lars_momentum_op.cc
@@ -0,0 +1,86 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/lars_momentum_op.h"
+#include "paddle/fluid/operators/momentum_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Param",
+             "(LoDTensor, default LoDTensor<float>) "
+             "Input parameter that has to be updated");
+    AddInput("Grad",
+             "(LoDTensor, default LoDTensor<float>) "
+             "Input gradient of the parameter");
+    AddInput("Velocity",
+             "(LoDTensor, default LoDTensor<float>) "
+             "Input velocity (corresponding to the parameter) "
+             "that has to be updated");
+    AddInput("LearningRate",
+             "(LoDTensor, default LoDTensor<float>) "
+             "Input learning rate");
+
+    AddOutput("ParamOut",
+              "(LoDTensor) This output is updated parameter. "
+              "It shared memory with Input(Param).");
+    AddOutput("VelocityOut",
+              "(LoDTensor) This output is updated velocity. "
+              "It shared memory with Input(Velocity).");
+
+    AddAttr<float>("mu", "(float) Momentum coefficient");
+    AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
+        .SetDefault(0.001);
+    AddAttr<float>("lars_weight_decay",
+                   "(float, default 0.0005) LARS weight decay")
+        .SetDefault(0.0005);
+
+    AddComment(R"DOC(
+Lars Momentum Optimizer.
+
+This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each
+weight using a local learning rate:
+
+$$
+local\_lr = \eta  * 
+    \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
+velocity = mu * velocity + 
+    local\_lr * (grad + \beta * param) \\
+param = param - velocity. \\
+$$
+
+Note that we use lars_weight_decay here to decay weights, you may need not to
+use L2 regularizers in case of using LARS.
+
+)DOC");
+  }
+};
+
+class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {}
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::LarsMomentumOpVarTypeInference);
+REGISTER_OP_CPU_KERNEL(lars_momentum, ops::LarsMomentumOpKernel<float>,
+                       ops::LarsMomentumOpKernel<double>);
diff --git a/paddle/fluid/operators/lars_momentum_op.cu b/paddle/fluid/operators/lars_momentum_op.cu
new file mode 100644
index 0000000000..eb346851a2
--- /dev/null
+++ b/paddle/fluid/operators/lars_momentum_op.cu
@@ -0,0 +1,94 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/lars_momentum_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
+                                   const T* learning_rate, const T mu,
+                                   const int64_t num, const T lars_coeff,
+                                   const T lars_weight_decay, const T* p_norm,
+                                   const T* g_norm, T* p_out, T* v_out) {
+  T lr = learning_rate[0];
+  T local_lr = learning_rate[0];
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    if (p_norm[0] > 0 && g_norm[0] > 0) {
+      local_lr = lr * lars_coeff * p_norm[0] /
+                 (g_norm[0] + lars_weight_decay * p_norm[0]);
+    }
+    T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
+    v_out[i] = v_new;
+    p_out[i] = p[i] - v_new;
+  }
+}
+
+template <typename DeviceContext, typename T>
+class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
+    auto param = ctx.Input<framework::LoDTensor>("Param");
+    auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
+    auto grad = ctx.Input<framework::LoDTensor>("Grad");
+    auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
+
+    T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
+    T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
+
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    T lars_coeff = ctx.Attr<float>("lars_coeff");
+    T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
+
+    auto* p = param->data<T>();
+    auto* v = velocity->data<T>();
+    auto* g = grad->data<T>();
+    auto* lr = learning_rate->data<T>();
+
+    int block = 512;
+    int grid = (param->numel() + block - 1) / block;
+
+    auto eigen_p = framework::EigenVector<T>::Flatten(*param);
+    auto eigen_g = framework::EigenVector<T>::Flatten(*grad);
+    // calculate norms using eigein and launch the kernel.
+    framework::Tensor p_norm_t, g_norm_t;
+    p_norm_t.Resize({1});
+    g_norm_t.Resize({1});
+    auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
+    auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
+    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+
+    auto* place = ctx.template device_context<DeviceContext>().eigen_device();
+    ep_norm.device(*place) = eigen_p.square().sum().sqrt();
+    eg_norm.device(*place) = eigen_g.square().sum().sqrt();
+    MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+        p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
+        p_norm_data, g_norm_data, p_out, v_out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    lars_momentum,
+    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/lars_momentum_op.h b/paddle/fluid/operators/lars_momentum_op.h
new file mode 100644
index 0000000000..e85be99fc4
--- /dev/null
+++ b/paddle/fluid/operators/lars_momentum_op.h
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class LarsMomentumOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
+    auto param = ctx.Input<framework::LoDTensor>("Param");
+    auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
+    auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
+    auto* grad_var = ctx.InputVar("Grad");
+    // only support dense for now.
+    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>());
+    auto grad = ctx.Input<framework::LoDTensor>("Grad");
+
+    param_out->mutable_data<T>(ctx.GetPlace());
+    velocity_out->mutable_data<T>(ctx.GetPlace());
+
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    T lars_coeff = ctx.Attr<float>("lars_coeff");
+    T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
+
+    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
+    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
+
+    auto p = framework::EigenVector<T>::Flatten(*param);
+    auto v = framework::EigenVector<T>::Flatten(*velocity);
+    auto g = framework::EigenVector<T>::Flatten(*grad);
+    auto* lr = learning_rate->data<T>();
+
+    framework::Tensor p_norm_t, g_norm_t;
+    p_norm_t.Resize({1});
+    g_norm_t.Resize({1});
+    p_norm_t.mutable_data<T>(ctx.GetPlace());
+    g_norm_t.mutable_data<T>(ctx.GetPlace());
+    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+
+    ep_norm = p.square().sum().sqrt();
+    eg_norm = g.square().sum().sqrt();
+    T local_lr = lr[0];
+    if (ep_norm(0) > 0 && eg_norm(0) > 0) {
+      local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                 (eg_norm(0) + lars_weight_decay * ep_norm(0));
+    }
+    v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+    p_out = p - v_out;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc
index 12b916fceb..7f0b51580a 100644
--- a/paddle/fluid/operators/momentum_op.cc
+++ b/paddle/fluid/operators/momentum_op.cc
@@ -19,54 +19,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-class MomentumOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Param"),
-                   "Input(param) of Momentum should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Grad"),
-                   "Input(grad) of Momentum should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Velocity"),
-                   "Input(velocity) of Momentum should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
-                   "Input(LearningRate) of Momentum should not be null.");
-    PADDLE_ENFORCE(
-        ctx->GetInputsVarType("Param").front() ==
-            framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
-
-    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
-                   "Output(ParamOut) of Momentum should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
-                   "Output(VelocityOut) of Momentum should not be null.");
-
-    auto param_dim = ctx->GetInputDim("Param");
-    if (ctx->GetInputsVarType("Grad")[0] ==
-        framework::proto::VarType::LOD_TENSOR) {
-      PADDLE_ENFORCE_EQ(
-          param_dim, ctx->GetInputDim("Grad"),
-          "Param and Grad input of MomentumOp should have the same dimension.");
-      PADDLE_ENFORCE_EQ(
-          param_dim, ctx->GetInputDim("Velocity"),
-          "Param and Velocity of MomentumOp should have the same dimension.");
-    }
-    PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
-                      "Learning_rate should be a scalar");
-
-    ctx->SetOutputDim("ParamOut", param_dim);
-    ctx->SetOutputDim("VelocityOut", param_dim);
-  }
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
-  }
-};
-
 class MomentumOpInferVarType : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc& op_desc,
diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h
index 6b4d00f56c..71f079e4d9 100644
--- a/paddle/fluid/operators/momentum_op.h
+++ b/paddle/fluid/operators/momentum_op.h
@@ -28,6 +28,54 @@ using framework::SelectedRows;
 struct NoNesterov;
 struct UseNesterov;
 
+class MomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Param"),
+                   "Input(param) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grad"),
+                   "Input(grad) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Velocity"),
+                   "Input(velocity) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of Momentum should not be null.");
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Param").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s",
+        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
+
+    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
+                   "Output(ParamOut) of Momentum should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
+                   "Output(VelocityOut) of Momentum should not be null.");
+
+    auto param_dim = ctx->GetInputDim("Param");
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Grad"),
+          "Param and Grad input of MomentumOp should have the same dimension.");
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Velocity"),
+          "Param and Velocity of MomentumOp should have the same dimension.");
+    }
+    PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
+                      "Learning_rate should be a scalar");
+
+    ctx->SetOutputDim("ParamOut", param_dim);
+    ctx->SetOutputDim("VelocityOut", param_dim);
+  }
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
 template <typename T>
 class CPUDenseMomentumFunctor {
  private:
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 339a7c98c6..5f15a29f4c 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -645,9 +645,13 @@ All parameter, weight, gradient are variables in Paddle.
 
   py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
   pass.def(py::init())
-      .def("set_str", [](ir::Pass &self, const std::string &name,
-                         const std::string &attr) {
-        self.Set<std::string>(name, new std::string(attr));
+      .def(
+          "set_str",
+          [](ir::Pass &self, const std::string &name, const std::string &attr) {
+            self.Set<std::string>(name, new std::string(attr));
+          })
+      .def("set_int", [](ir::Pass &self, const std::string &name, int val) {
+        self.Set<const int>(name, new int(val));
       });
 
   py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index dfd801a098..149224bb68 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -27,7 +27,7 @@ from . import nn
 from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
-from ..framework import default_main_program, Parameter, unique_name
+from ..framework import default_main_program, Parameter, unique_name, name_scope
 
 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
@@ -332,14 +332,16 @@ def append_LARS(params_grads, learning_rate, weight_decay):
             return grad_norm + weight_decay * param_norm
 
     for param, grad in params_grads:
-        param_lr = param.optimize_attr['learning_rate']
-        param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param)))
-        grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad)))
-        if type(param_lr) == float and param_lr == 1.0:
-            decayed_lr = learning_rate * param_norm \
-                / _balanced_weight(param_norm, grad_norm)
-        else:
-            decayed_lr = learning_rate * param_lr * param_norm \
-                / _balanced_weight(param_norm, grad_norm)
-        # set back param local learning rate
-        param.optimize_attr['learning_rate'] = decayed_lr
+        with param.block.program.optimized_guard(
+            [param, grad]), name_scope("optimizer"):
+            param_lr = param.optimize_attr['learning_rate']
+            param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param)))
+            grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad)))
+            if type(param_lr) == float and param_lr == 1.0:
+                decayed_lr = learning_rate * param_norm \
+                    / _balanced_weight(param_norm, grad_norm)
+            else:
+                decayed_lr = learning_rate * param_lr * param_norm \
+                    / _balanced_weight(param_norm, grad_norm)
+            # set back param local learning rate
+            param.optimize_attr['learning_rate'] = decayed_lr
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 6ea280c733..7e2364a5a8 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -14,6 +14,7 @@
 
 from __future__ import print_function
 import re
+import sys
 from collections import defaultdict
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
 from . import framework
@@ -32,7 +33,8 @@ __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
     'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
     'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
-    'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer'
+    'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
+    'LarsMomentumOptimizer'
 ]
 
 
@@ -105,7 +107,6 @@ class Optimizer(object):
         param = param_and_grad[0]
         param_lr = param.optimize_attr['learning_rate']
         if type(param_lr) == Variable:
-            print("returns updated param lr ", param_lr)
             return param_lr
         else:
             if param_lr == 1.0:
@@ -400,6 +401,91 @@ class MomentumOptimizer(Optimizer):
         return momentum_op
 
 
+class LarsMomentumOptimizer(Optimizer):
+    """
+    Momentum optimizer with LARS support
+
+    The update equations are as follows:
+
+    .. math::
+
+        & local\_learning\_rate = learning\_rate * lars\_coeff * \\
+          \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||}
+
+        & velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param)
+
+        & param = param - velocity
+
+    Args:
+        learning_rate (float|Variable): the learning rate used to update parameters. \
+        Can be a float value or a Variable with one float value as data element.
+        momentum (float): momentum factor
+        lars_coeff (float): defines how much we trust the layer to change its weights.
+        lars_weight_decay (float): weight decay coefficient for decaying using LARS.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
+        
+
+    Examples:
+        .. code-block:: python
+
+            optimizer = fluid.optimizer.LarsMomentum(learning_rate=0.2, momentum=0.1, lars_weight_decay=0.001)
+            optimizer.minimize(cost)
+    """
+    _velocity_acc_str = "velocity"
+
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 lars_coeff=0.001,
+                 lars_weight_decay=0.0005,
+                 regularization=None,
+                 name=None):
+        assert learning_rate is not None
+        assert momentum is not None
+        super(LarsMomentumOptimizer, self).__init__(
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
+        self.type = "lars_momentum"
+        self._momentum = momentum
+        self._lars_coeff = float(lars_coeff)
+        self._lars_weight_decay = float(lars_weight_decay)
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+
+        for p in parameters:
+            self._add_accumulator(self._velocity_acc_str, p)
+
+    def _append_optimize_op(self, block, param_and_grad):
+        assert isinstance(block, framework.Block)
+
+        velocity_acc = self._get_accumulator(self._velocity_acc_str,
+                                             param_and_grad[0])
+        # create the momentum optimize op
+        momentum_op = block.append_op(
+            type=self.type,
+            inputs={
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "Velocity": velocity_acc,
+                "LearningRate": self._create_param_lr(param_and_grad)
+            },
+            outputs={
+                "ParamOut": param_and_grad[0],
+                "VelocityOut": velocity_acc
+            },
+            attrs={
+                "mu": self._momentum,
+                "lars_coeff": self._lars_coeff,
+                "lars_weight_decay": self._lars_weight_decay
+            })
+
+        return momentum_op
+
+
 class AdagradOptimizer(Optimizer):
     """
     **Adaptive Gradient Algorithm (Adagrad)**
@@ -1221,6 +1307,7 @@ DecayedAdagrad = DecayedAdagradOptimizer
 Adadelta = AdadeltaOptimizer
 RMSProp = RMSPropOptimizer
 Ftrl = FtrlOptimizer
+LarsMomentum = LarsMomentumOptimizer
 
 
 class ModelAverage(Optimizer):
diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py
index 877d21ae88..01e9795d8b 100644
--- a/python/paddle/fluid/tests/unittests/dist_mnist.py
+++ b/python/paddle/fluid/tests/unittests/dist_mnist.py
@@ -95,7 +95,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
 
         # Reader
         train_reader = paddle.batch(
-            paddle.dataset.mnist.train(), batch_size=batch_size)
+            paddle.dataset.mnist.test(), batch_size=batch_size)
         test_reader = paddle.batch(
             paddle.dataset.mnist.test(), batch_size=batch_size)
         opt.minimize(avg_cost)
diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py
new file mode 100644
index 0000000000..d386e75fd8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py
@@ -0,0 +1,80 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import time
+import math
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import os
+import signal
+from functools import reduce
+from test_dist_base import TestDistRunnerBase, runtime_main
+from dist_mnist import cnn_model
+
+DTYPE = "float32"
+
+
+def test_merge_reader(repeat_batch_size=8):
+    orig_reader = paddle.dataset.mnist.test()
+    record_batch = []
+    b = 0
+    for d in orig_reader():
+        if b >= repeat_batch_size:
+            break
+        record_batch.append(d)
+        b += 1
+    while True:
+        for d in record_batch:
+            yield d
+
+
+class TestDistMnist2x2(TestDistRunnerBase):
+    def get_model(self, batch_size=2):
+        # Input data
+        images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+        # Train program
+        predict = cnn_model(images)
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        avg_cost = fluid.layers.mean(x=cost)
+
+        # Evaluator
+        batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
+        batch_acc = fluid.layers.accuracy(
+            input=predict, label=label, total=batch_size_tensor)
+
+        inference_program = fluid.default_main_program().clone()
+        # Optimization
+        opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
+
+        # Reader
+        train_reader = paddle.batch(test_merge_reader, batch_size=batch_size)
+        test_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+        opt.minimize(avg_cost)
+        return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
+
+
+if __name__ == "__main__":
+    runtime_main(TestDistMnist2x2)
diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_lars.py b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py
new file mode 100644
index 0000000000..977e17c37f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py
@@ -0,0 +1,73 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import time
+import math
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import os
+import signal
+from functools import reduce
+from test_dist_base import TestDistRunnerBase, runtime_main
+from dist_mnist import cnn_model
+
+DTYPE = "float32"
+paddle.dataset.mnist.fetch()
+
+# Fix seed for test
+fluid.default_startup_program().random_seed = 1
+fluid.default_main_program().random_seed = 1
+
+
+class TestDistMnist2x2(TestDistRunnerBase):
+    def get_model(self, batch_size=2):
+        # Input data
+        images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+        # Train program
+        predict = cnn_model(images)
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        avg_cost = fluid.layers.mean(x=cost)
+
+        # Evaluator
+        batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
+        batch_acc = fluid.layers.accuracy(
+            input=predict, label=label, total=batch_size_tensor)
+
+        inference_program = fluid.default_main_program().clone()
+        # Optimization
+        opt = fluid.optimizer.LarsMomentumOptimizer(
+            learning_rate=0.001, momentum=0.9)
+
+        # Reader
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+        test_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+        opt.minimize(avg_cost)
+        return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
+
+
+if __name__ == "__main__":
+    runtime_main(TestDistMnist2x2)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 04924bec05..87fd03ca61 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -26,10 +26,11 @@ import argparse
 import paddle.fluid as fluid
 
 RUN_STEP = 10
+DEFAULT_BATCH_SIZE = 2
 
 
 class TestDistRunnerBase(object):
-    def get_model(self, batch_size=2):
+    def get_model(self, batch_size=DEFAULT_BATCH_SIZE):
         raise NotImplementedError(
             "get_model should be implemented by child classes.")
 
@@ -48,8 +49,7 @@ class TestDistRunnerBase(object):
         return t
 
     def run_pserver(self, args):
-
-        self.get_model(batch_size=2)
+        self.get_model(batch_size=args.batch_size)
         # NOTE: pserver should not call memory optimize
         t = self.get_transpiler(args.trainer_id,
                                 fluid.default_main_program(), args.endpoints,
@@ -65,7 +65,7 @@ class TestDistRunnerBase(object):
 
     def run_trainer(self, args):
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
-            self.get_model(batch_size=2)
+            self.get_model(batch_size=args.batch_size)
 
         if args.mem_opt:
             fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
@@ -92,6 +92,11 @@ class TestDistRunnerBase(object):
         strategy.allow_op_delay = False
 
         build_stra = fluid.BuildStrategy()
+        if args.batch_merge_repeat > 1:
+            pass_builder = build_stra._create_passes_from_strategy()
+            mypass = pass_builder.insert_pass(
+                len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
+            mypass.set_int("num_repeats", args.batch_merge_repeat)
 
         if args.use_reduce:
             build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
@@ -145,6 +150,9 @@ def runtime_main(test_class):
     parser.add_argument('--use_reduce', action='store_true')
     parser.add_argument(
         '--use_reader_alloc', action='store_true', required=False, default=True)
+    parser.add_argument('--batch_size', required=False, type=int, default=2)
+    parser.add_argument(
+        '--batch_merge_repeat', required=False, type=int, default=1)
 
     args = parser.parse_args()
 
@@ -244,9 +252,18 @@ class TestDistBase(unittest.TestCase):
                                  (e, retry_times))
                 retry_times -= 1
 
-    def _run_local(self, model, envs, check_error_log):
+    def _run_local(self,
+                   model,
+                   envs,
+                   check_error_log=False,
+                   batch_size=DEFAULT_BATCH_SIZE,
+                   batch_merge_repeat=1):
 
         cmd = "%s %s --role trainer" % (self._python_interp, model)
+        if batch_size != DEFAULT_BATCH_SIZE:
+            cmd += " --batch_size %d" % batch_size
+        if batch_merge_repeat > 1:
+            cmd += " --batch_merge_repeat %d" % batch_merge_repeat
 
         if self.__use_cuda:
             cmd += " --use_cuda"
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
index f65dd7e2a2..922dd838f8 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
@@ -26,6 +26,15 @@ class TestDistMnist2x2(TestDistBase):
         self.check_with_place("dist_mnist.py", delta=1e-5)
 
 
+class TestDistMnist2x2Lars(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+
+    def test_se_resnext(self):
+        self.check_with_place("dist_mnist_lars.py", delta=1e-5)
+
+
 class TestDistMnist2x2WithMemopt(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py
new file mode 100644
index 0000000000..22d4b79290
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py
@@ -0,0 +1,67 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+from test_dist_base import TestDistBase
+import os
+
+
+class TestDistMnist2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+
+    def test_dist_train(self):
+        self.check_with_place("dist_mnist_batch_merge.py", delta=1e-5)
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        # TODO(typhoonzero): should auto adapt GPU count on the machine.
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
+            "FLAGS_cudnn_deterministic": "1",
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "7"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        no_merge_losses = self._run_local(
+            model_file,
+            required_envs,
+            check_error_log=check_error_log,
+            batch_size=4)
+
+        batch_merge_losses = self._run_local(
+            model_file,
+            required_envs,
+            check_error_log=check_error_log,
+            batch_size=2,
+            batch_merge_repeat=2)
+        # Ensure both result have values.
+        self.assertGreater(len(no_merge_losses), 1)
+        self.assertEqual(len(no_merge_losses), len(batch_merge_losses))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index a3d89610b4..cf4346cf2e 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -90,6 +90,45 @@ class TestMomentumOp2(OpTest):
         self.check_output()
 
 
+class TestLarsMomentumOp(OpTest):
+    def setUp(self):
+        self.op_type = "lars_momentum"
+
+        param = np.random.random((123, 321)).astype("float32")
+        grad = np.random.random((123, 321)).astype("float32")
+        velocity = np.zeros((123, 321)).astype("float32")
+        learning_rate = np.array([0.001]).astype("float32")
+        mu = 0.0001
+        lars_coeff = 0.001
+        lars_weight_decay = 0.0005
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Velocity': velocity,
+            'LearningRate': learning_rate
+        }
+
+        self.attrs = {
+            'mu': mu,
+            'lars_coeff': lars_coeff,
+            'lars_weight_decay': lars_weight_decay
+        }
+
+        pnorm = np.sqrt(np.square(param).sum())
+        gnorm = np.sqrt(np.square(grad).sum())
+        local_lr = learning_rate * lars_coeff * pnorm / (
+            gnorm + lars_weight_decay * param)
+        velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay *
+                                                   param)
+        param_out = param - velocity_out
+
+        self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestSparseMomentumOp(unittest.TestCase):
     def setUp(self):
         self.use_nesterov = False
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 28d7df8e45..28ad844367 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -1431,7 +1431,7 @@ to transpile() call.")
         elif op_type == "adamax":
             if varkey in ["Moment", "InfNorm"]:
                 return param_shape
-        elif op_type == "momentum":
+        elif op_type in ["momentum", "lars_momentum"]:
             if varkey == "Velocity":
                 return param_shape
         elif op_type == "rmsprop":
@@ -1442,6 +1442,10 @@ to transpile() call.")
                 return param_shape
         elif op_type == "sgd":
             pass
+        else:
+            raise ValueError(
+                "Not supported optimizer for distributed training: %s" %
+                op_type)
         return orig_shape
 
     def _get_varname_parts(self, varname):