diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 43c2eb7178..a68b69e026 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -18,7 +18,7 @@ limitations under the License. */
 #include <memory>
 
 #include "paddle/fluid/framework/details/memory_reuse_types.h"
-#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
+#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/sequential_execution_pass.h"
@@ -86,10 +86,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     if (strategy.memory_optimize_) {
       auto analysis_var_pass = AppendPass("analysis_var_pass");
     }
-    // Convert graph to run on multi-devices.
-    auto multi_devices_pass = AppendPass("multi_devices_pass");
-    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
-                                                         &strategy_);
+
+    AppendMultiDevPass(strategy);
 
     // Add a graph print pass to record a graph with device info.
     if (!strategy_.debug_graphviz_path_.empty()) {
@@ -115,6 +113,25 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     }
   }
 
+  // Convert graph to run on multi-devices.
+  void AppendMultiDevPass(const BuildStrategy &strategy) {
+    ir::Pass *multi_devices_pass;
+    if (strategy_.is_distribution_) {
+      multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
+    } else {
+      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+        multi_devices_pass =
+            AppendPass("allreduce_mode_multi_devices_pass").get();
+      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
+        multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
+      } else {
+        PADDLE_THROW("Unknown reduce strategy.");
+      }
+    }
+    multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
+                                                         &strategy_);
+  }
+
  private:
   BuildStrategy strategy_;
 };
@@ -131,6 +148,10 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
   return pass_builder_;
 }
 
+bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
+  return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
+}
+
 std::unique_ptr<ir::Graph> BuildStrategy::Apply(
     const ProgramDesc &main_program, const std::vector<platform::Place> &places,
     const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
@@ -145,22 +166,23 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
-    if (pass->Type() == "multi_devices_pass") {
-      pass->Erase("places");
-      pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
-      pass->Erase("loss_var_name");
-      pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
-      pass->Erase("local_scopes");
-      pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
+    if (IsMultiDevPass(pass->Type())) {
+      pass->Erase(kPlaces);
+      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
+      pass->Erase(kLossVarName);
+      pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
+      pass->Erase(kLocalScopes);
+      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                     &local_scopes);
-      pass->Erase("nranks");
-      pass->Set<size_t>("nranks", new size_t(nranks));
+      pass->Erase(kNRanks);
+      pass->Set<size_t>(kNRanks, new size_t(nranks));
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
       pass->Erase("nccl_ctxs");
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
+
     } else if (pass->Type() == "analysis_var_pass") {
       const std::vector<OpDesc *> *all_op_descs =
           new std::vector<OpDesc *>(main_program.Block(0).AllOps());
@@ -201,7 +223,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
 USE_PASS(fuse_elewise_add_act_pass);
 USE_PASS(graph_viz_pass);
 USE_PASS(multi_batch_merge_pass);
-USE_PASS(multi_devices_pass);
+USE_PASS(reduce_mode_multi_devices_pass);
+USE_PASS(allreduce_mode_multi_devices_pass);
+USE_PASS(dist_multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
 USE_PASS(analysis_var_pass);
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index b75c01c485..15c2e01b61 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -74,8 +74,6 @@ struct BuildStrategy {
 
   bool fuse_elewise_add_act_ops_{false};
 
-  bool enable_data_balance_{false};
-
   bool memory_optimize_{false};
 
   bool memory_early_delete_{false};
@@ -84,6 +82,10 @@ struct BuildStrategy {
 
   bool fuse_broadcast_op_{false};
 
+  // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
+  // num_trainers is 1, so the current fields of build_strategy doesn't tell if
+  // it's distributed model.
+  bool is_distribution_{false};
   int num_trainers_{1};
   int trainer_id_{0};
   std::vector<std::string> trainers_endpoints_;
@@ -104,6 +106,8 @@ struct BuildStrategy {
 
   bool IsFinalized() const { return is_finalized_; }
 
+  bool IsMultiDevPass(const std::string &pass_name) const;
+
   // Apply the passes built by the pass_builder_. The passes will be
   // applied to the Program and output an ir::Graph.
   std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program,
diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
index c8ea188046..a4bb1e26d9 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include <string>
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 
@@ -21,68 +21,78 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
-  std::unordered_map<OpHandleBase *, size_t> pending_ops;
-  std::unordered_set<VarHandleBase *> pending_vars;
-  std::unordered_set<VarHandleBase *> ready_vars;
-  std::unordered_set<OpHandleBase *> ready_ops;
+class SSAGraghBuilderWithChecker : public ir::Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override {
+    PADDLE_ENFORCE(IsValidGraph(graph.get()));
+    return graph;
+  }
 
-  auto insert_pending_var = [&](VarHandleBase *var) {
-    pending_vars.insert(var);
-    if (var->GeneratedOp() == nullptr) {
-      ready_vars.emplace(var);
-    }
-  };
+  bool IsValidGraph(const ir::Graph *graph) const {
+    std::unordered_map<OpHandleBase *, size_t> pending_ops;
+    std::unordered_set<VarHandleBase *> pending_vars;
+    std::unordered_set<VarHandleBase *> ready_vars;
+    std::unordered_set<OpHandleBase *> ready_ops;
 
-  for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
-    for (auto &name_pair : var_map) {
-      for (auto &version_pair : name_pair.second) {
-        insert_pending_var(version_pair);
+    auto insert_pending_var = [&](VarHandleBase *var) {
+      pending_vars.insert(var);
+      if (var->GeneratedOp() == nullptr) {
+        ready_vars.emplace(var);
       }
-    }
-  }
+    };
 
-  for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
-    insert_pending_var(var);
-  }
+    for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
+      for (auto &name_pair : var_map) {
+        for (auto &version_pair : name_pair.second) {
+          insert_pending_var(version_pair);
+        }
+      }
+    }
 
-  for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
-    if (op->Inputs().empty()) {
-      ready_ops.insert(op);
-    } else {
-      pending_ops.insert({op, op->NoDupInputSize()});
+    for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
+      insert_pending_var(var);
     }
-  }
 
-  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
-    for (auto *op : set) {
-      for (auto out : op->Outputs()) {
-        ready_vars.emplace(out);
+    for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
+      if (op->Inputs().empty()) {
+        ready_ops.insert(op);
+      } else {
+        pending_ops.insert({op, op->NoDupInputSize()});
       }
     }
-    set.clear();
-  };
 
-  while (!pending_vars.empty()) {
-    run_all_ops(ready_ops);
+    auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+      for (auto *op : set) {
+        for (auto out : op->Outputs()) {
+          ready_vars.emplace(out);
+        }
+      }
+      set.clear();
+    };
 
-    if (ready_vars.empty()) {
-      return false;
-    }
+    while (!pending_vars.empty()) {
+      run_all_ops(ready_ops);
 
-    for (auto ready_var : ready_vars) {
-      pending_vars.erase(ready_var);
-      for (auto *op : ready_var->PendingOps()) {
-        auto &deps = --pending_ops[op];
-        if (deps == 0) {
-          ready_ops.insert(op);
+      if (ready_vars.empty()) {
+        return false;
+      }
+
+      for (auto ready_var : ready_vars) {
+        pending_vars.erase(ready_var);
+        for (auto *op : ready_var->PendingOps()) {
+          auto &deps = --pending_ops[op];
+          if (deps == 0) {
+            ready_ops.insert(op);
+          }
         }
       }
+      ready_vars.clear();
     }
-    ready_vars.clear();
+    return true;
   }
-  return true;
-}
+};
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.h b/paddle/fluid/framework/details/multi_devices_graph_check_pass.h
deleted file mode 100644
index 1e2b1867c3..0000000000
--- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/details/multi_devices_helper.h"
-
-#include <string>
-
-namespace paddle {
-namespace framework {
-namespace details {
-
-class SSAGraghBuilderWithChecker : public ir::Pass {
- protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
-    PADDLE_ENFORCE(IsValidGraph(graph.get()));
-    return graph;
-  }
-
-  bool IsValidGraph(const ir::Graph* graph) const;
-};
-
-}  // namespace details
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 761c9ab904..d91993bd4f 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -134,15 +134,8 @@ void AddOutputToLeafOps(ir::Graph *graph) {
 }
 }  // namespace
 
-static const char kLossVarName[] = "loss_var_name";
-static const char kPlaces[] = "places";
-static const char kLocalScopes[] = "local_scopes";
-static const char kStrategy[] = "strategy";
-static const char kNRanks[] = "nranks";
-
-void MultiDevSSAGraphBuilder::Init() const {
+void MultiDevSSAGraphBuilderBase::Init() const {
   all_vars_.clear();
-  balance_vars_.clear();
 
   loss_var_name_ = Get<const std::string>(kLossVarName);
   places_ = Get<const std::vector<platform::Place>>(kPlaces);
@@ -151,31 +144,16 @@ void MultiDevSSAGraphBuilder::Init() const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
 #endif
-
-  balance_vars_.resize(places_.size(), 0);
-
-  if (strategy_.enable_data_balance_ && places_.size() == 1) {
-    LOG(WARNING) << "It is no need to enable data balance when there is only "
-                    "one place. enable_data_balance is set to False.";
-    strategy_.enable_data_balance_ = false;
-  }
 }
 
-std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
+std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   Init();
-  // Give the topology sort order and rebuild the graph structure.
-  std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(*graph);
-
-  if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
-    sorted_ops = SortForReduceMode(sorted_ops);
-  }
+  std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
 
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
 
-  size_t nranks = Get<size_t>(kNRanks);
-
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
       all_vars_.emplace(node->Name(), node->Var());
@@ -187,146 +165,61 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   result.Set(kGraphDepVars, new GraphDepVars);
   result.Set(kGraphOps, new GraphOps);
 
-  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
-  bcast_var_name_set.resize(places_.size());
-
   bool is_forwarding = true;
-  bool is_dist_train = false;
-
-  std::unordered_map<std::string, int> sharded_var_device;
+  bool insert_collection_ops = NeedCollectiveOps();
 
   for (ir::Node *node : sorted_ops) {
-    if (OpHaveRole(*node, OpRole::kRPC)) {
-      int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
-      PADDLE_ENFORCE(op_dev_id != -1,
-                     "Can not schedule the RPC operator to the right place.");
-      if (node->Op()->Type() == "recv") {
-        auto recv_vars_attr =
-            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
-                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-        PADDLE_ENFORCE(recv_vars_attr.size() == 2UL);  // [parameter, gradient]
-        if (recv_vars_attr[0].find(".block") == std::string::npos) {
-          bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
-        }
-      }
-      is_dist_train = true;
-    } else if (OpHaveRole(*node, OpRole::kDist)) {
-      int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
-      if (node->Op()->Type() == "concat") {
-        auto origin_param_name = node->Op()->OutputArgumentNames()[0];
-        bcast_var_name_set[op_dev_id].emplace(origin_param_name);
-      }
-    } else if (IsScaleLossOp(node)) {
-      // user can customize loss@grad if not use_default_grad_scale_
-      if (strategy_.gradient_scale_ !=
-          BuildStrategy::GradientScaleStrategy::kCustomized) {
-        // TODO(paddle-dev): Why is there no input for this op_handle?
-        auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
-        auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
-        CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
-                              out_dtype);
-      }
-      // This assumes the backward generating code will ensure IsScaleLossOp
-      // is true only for the op that scale the final scalar loss.
-      // It also assumes backward op will always follow the forward op in
-      // the block.
-      is_forwarding = false;
+    if (DealWithSpecialOp(&result, node)) {
+      continue;
     } else {
-      int op_dev_id = GetOpDeviceID(node, sharded_var_device);
-      if (op_dev_id != -1) {  // This op only runs on one specific device.
-        CreateComputationalOp(&result, node, op_dev_id);
-        for (ir::Node *n : node->outputs) {
-          sharded_var_device.emplace(n->Name(), op_dev_id);
-        }
+      // This op runs on all devices
+      if (IsScaleLossOp(node)) {
+        // user can customize loss@grad if not use_default_grad_scale_
+        InsertScaleLossGradOp(&result, node);
+        // This assumes the backward generating code will ensure IsScaleLossOp
+        // is true only for the op that scale the final scalar loss.
+        // It also assumes backward op will always follow the forward op in
+        // the block.
+        is_forwarding = false;
       } else {
-        // This op runs on all devices, and its output may have parameter's
-        // gradients.
-        // TODO(paddle-dev): Why is so special about "read" op?
-        if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
-          node->Op()->SetAttr("throw_eof_exp", false);
-          CreateComputationalOps(&result, node, places_.size());
-          const auto &data_var_names = node->Op()->Output("Out");
-          InsertDataBalanceOp(&result, data_var_names);
-        } else {
-          CreateComputationalOps(&result, node, places_.size());
-        }
+        CreateComputationalOps(&result, node, places_.size());
+      }
 
-        if (!is_forwarding && nranks > 1UL) {
+      // Insert collection ops
+      if (!is_forwarding && insert_collection_ops) {
+        try {
           bool is_bk_op =
               static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                     OpProtoAndCheckerMaker::OpRoleAttrName())) &
                                 static_cast<int>(OpRole::kBackward));
           if (!is_bk_op) continue;
+
           // Currently, we assume that once gradient is generated, it can be
           // broadcast, and each gradient is only broadcast once.
-          try {
-            auto backward_vars = boost::get<std::vector<std::string>>(
-                node->Op()->GetNullableAttr(
-                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-
-            PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
-
-            for (size_t i = 0; i < backward_vars.size(); i += 2) {
-              auto &p_name = backward_vars[i];
-              auto &g_name = backward_vars[i + 1];
-              VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
-              size_t cur_device_id = -1;
-              switch (strategy_.reduce_) {
-                case BuildStrategy::ReduceStrategy::kReduce:
-                  cur_device_id = GetAppropriateDeviceID({g_name});
-                  CreateReduceOp(&result, g_name, cur_device_id);
-                  sharded_var_device.emplace(g_name, cur_device_id);
-                  if (!is_dist_train) {
-                    bcast_var_name_set[cur_device_id].emplace(p_name);
-                  }
-                  break;
-                case BuildStrategy::ReduceStrategy::kAllReduce:
-                  if (IsSparseGradient(g_name)) {
-                    CreateReduceOp(&result, g_name, 0);
-                    CreateBroadcastOp(&result, g_name, 0);
-                  } else {
-                    InsertAllReduceOp(&result, g_name);
-                  }
-                  break;
-                default:
-                  LOG(FATAL) << "Unknown reduce strategy ";
-                  break;
-              }
-            }
-          } catch (boost::bad_get e) {
+          auto backward_vars =
+              boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                  OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+          PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+          for (size_t i = 0; i < backward_vars.size(); i += 2) {
+            auto &p_name = backward_vars[i];
+            auto &g_name = backward_vars[i + 1];
+            VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+
+            InsertCollectiveOp(&result, p_name, g_name);
           }
+        } catch (boost::bad_get e) {
         }
       }
     }
   }
-  bool use_gpu = false;
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  use_gpu = nccl_ctxs_ != nullptr;
-#endif
 
-  // Insert broadcast operators principle:
-  // 1. Broadcast optimized parameters in Reduce strategy;
-  // 2. No need broadcast optimized parameters in AllReduce strategy because of
-  //    the optimization sub-graph would be run on every GPU;
-  // 3. Allways broadcast received parameters in Distribute Training.
-  if ((use_gpu &&
-       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
-      is_dist_train) {
-    if (strategy_.fuse_broadcast_op_) {
-      CreateFusedBroadcastOp(&result, bcast_var_name_set);
-    } else {
-      for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
-        auto &to_bcast_set = bcast_var_name_set[dev_id];
-        for (auto &bcast_name : to_bcast_set) {
-          CreateBroadcastOp(&result, bcast_name, dev_id);
-        }
-      }
-    }
-  }
+  InsertPostprocessOps(&result);
+
   /*
   Dependency graph has been constructed. However, there are still data
   hazards need to be handled.
- */
+  */
   PolishGraphToSupportDataHazards(&result);
 
   /*
@@ -337,67 +230,54 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   return graph;
 }
 
-std::vector<ir::Node *> MultiDevSSAGraphBuilder::SortForReduceMode(
-    const std::vector<ir::Node *> &topo_ops) const {
-  std::unordered_map<std::string, int> sharded_var_device;
-  std::vector<ir::Node *> sorted_ops;
-  std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op;
-  sorted_ops.reserve(topo_ops.size());
-
-  auto insert_delayed_op = [&](const std::string &var_name, int dev_id) {
-    sharded_var_device.emplace(var_name, dev_id);
-    if (delayed_op.count(var_name)) {
-      auto &ops = delayed_op.at(var_name);
-      sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end());
-      delayed_op.at(var_name).clear();
-    }
-  };
+void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
+    ir::Graph *result, const ir::Node *node) const {
+  // user can customize loss@grad if not use_default_grad_scale_
+  size_t loss_scale = 0;
+  switch (this->strategy_.gradient_scale_) {
+    case BuildStrategy::GradientScaleStrategy::kOne:
+      loss_scale = 1;
+      break;
+    case BuildStrategy::GradientScaleStrategy::kCoeffNumDevice:
+      loss_scale = Get<size_t>(kNRanks);
+      break;
+    case BuildStrategy::GradientScaleStrategy::kCustomized:
+      loss_scale = 0;
+      break;
+    default:
+      LOG(FATAL) << "Unknown gradient scale strategy.";
+      break;
+  }
+
+  if (loss_scale) {
+    // TODO(paddle-dev): Why is there no input for this op_handle?
+    auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
+    auto out_dtype = this->all_vars_.at(loss_grad_name)->GetDataType();
+    this->CreateScaleLossGradOp(result, loss_grad_name, node->outputs[0],
+                                loss_scale, out_dtype);
+  }
+}
 
-  for (ir::Node *node : topo_ops) {
-    int op_dev_id = GetOpDeviceID(node, sharded_var_device, &delayed_op);
-    if (op_dev_id > -1) {
-      // This op only runs on one specific device.
-      sorted_ops.emplace_back(node);
-      for (ir::Node *n : node->outputs) {
-        insert_delayed_op(n->Name(), op_dev_id);
-      }
-    } else if (op_dev_id == -1) {
-      // This op runs on all devices, and its output may have parameter's
-      // gradients.
-      sorted_ops.emplace_back(node);
-      bool is_bk_op =
-          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
-                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
-                            static_cast<int>(OpRole::kBackward));
-      if (!is_bk_op) continue;
-      // Currently, we assume that once gradient is generated, it can be
-      // broadcast, and each gradient is only broadcast once.
-      std::vector<std::string> backward_vars;
-      try {
-        backward_vars =
-            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
-                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-      } catch (boost::bad_get e) {
-      }
-      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
+    const ir::Graph &graph) const {
+  return ir::TopologySortOperations(graph);
+}
 
-      for (size_t i = 0; i < backward_vars.size(); i += 2) {
-        auto &g_name = backward_vars[i + 1];
-        size_t cur_device_id = GetAppropriateDeviceID({g_name});
-        insert_delayed_op(g_name, static_cast<int>(cur_device_id));
-      }
-    } else if (op_dev_id == -2) {
-      // The Op on which the Op depends has not yet been generated.
-    }
-  }
+bool MultiDevSSAGraphBuilderBase::UseGPU() const {
+  bool use_gpu = false;
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  use_gpu = nccl_ctxs_ != nullptr;
+#endif
+  return use_gpu;
+}
 
-  PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
-  return sorted_ops;
+bool MultiDevSSAGraphBuilderBase::NeedCollectiveOps() const {
+  return Get<size_t>(kNRanks) > 1;
 }
 
-void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
-                                                ir::Node *node,
-                                                size_t place_id) const {
+void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
+                                                    ir::Node *node,
+                                                    size_t place_id) const {
   auto p = places_[place_id];
   auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   op_handle->SetDeviceContext(p,
@@ -420,28 +300,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
   }
 }
 
-size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
-    const std::vector<std::string> &var_names) const {
-  int64_t numel_sum = 0;
-  for (auto var_name : var_names) {
-    if (all_vars_.find(var_name) == all_vars_.end()) continue;
-    auto var_desc = all_vars_.at(var_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GT(numel, 0);
-    numel_sum += numel;
-  }
-
-  auto smallest =
-      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
-  size_t dev_id =
-      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
-  balance_vars_[dev_id] += numel_sum;
-  return dev_id;
-}
-
-void MultiDevSSAGraphBuilder::SetCommunicationContext(
+void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
     OpHandleBase *op_handle, const platform::Place &p) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (nccl_ctxs_ == nullptr) {
@@ -454,9 +313,9 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
 #endif
 }
 
-void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
-                                                const std::string &p_name,
-                                                size_t src_dev_id) const {
+void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
+                                                    const std::string &p_name,
+                                                    size_t src_dev_id) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   auto *op_handle = new BroadcastOpHandle(
       result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
@@ -484,7 +343,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
+void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
     ir::Graph *result,
     const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
@@ -522,17 +381,17 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
-                                                    ir::Node *node,
-                                                    int dev_id) const {
+void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
+                                                        ir::Node *node,
+                                                        int dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id], dev_id));
   CreateOpHandleIOs(result, node, dev_id);
 }
 
-void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
-                                                const std::string &og) const {
+void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
+    ir::Graph *result, const std::string &og) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@@ -560,102 +419,15 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
   }
 }
 
-void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
-    ir::Graph *result, const std::vector<std::string> &datas) const {
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
-      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
-      local_scopes_, places_, nccl_ctxs_));
-#else
-  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
-      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
-      local_scopes_, places_));
-#endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
-  for (size_t i = 0; i < places_.size(); ++i) {
-    auto &p = places_[i];
-    SetCommunicationContext(op_handle, p);
-    for (const std::string &d_name : datas) {
-      auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
-      PADDLE_ENFORCE(!vars.empty());
-      op_handle->AddInput(vars.back());
-      auto var = new VarHandle(
-          result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
-          vars.size(), i, d_name, p);
-      vars.emplace_back(var);
-      op_handle->AddOutput(var);
-    }
-  }
-}
-
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    ir::Node *node,
-    const std::unordered_map<std::string, int> &sharded_var_device,
-    std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
-  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
-    return -1;
-  }
-
-  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
-    return -1;
-  }
-
-  auto param_grad = boost::get<std::vector<std::string>>(
-      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-
-  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
-  int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
-
-  if (dev_id == -1) {
-    (*delay_ops)[param_grad[1]].push_back(node);
-    return -2;
-  }
-  return dev_id;
-}
-
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    ir::Node *node,
-    const std::unordered_map<std::string, int> &sharded_var_device) const {
-  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
-    return -1;
-  }
-
-  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
-    return -1;
-  }
-  auto param_grad = boost::get<std::vector<std::string>>(
-      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-
-  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
-  int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
-  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
-                    node->Op()->Type(), param_grad[0], param_grad[1]);
-  return dev_id;
-}
-
-int MultiDevSSAGraphBuilder::GetVarDeviceID(
-    const std::string &varname,
-    const std::unordered_map<std::string, int> &sharded_var_device) const {
-  auto got = sharded_var_device.find(varname);
-  if (got == sharded_var_device.end()) {
-    auto pos = varname.find(framework::kNewGradSuffix);
-    if (pos != std::string::npos) {
-      got = sharded_var_device.find(varname.substr(0, pos));
-    }
-  }
-  return got == sharded_var_device.end() ? -1 : got->second;
-}
-
-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
+void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(
     ir::Graph *result, const std::string &loss_grad_name,
-    ir::Node *out_var_node, proto::VarType::Type dtype) const {
-  size_t nranks = Get<size_t>("nranks");
+    ir::Node *out_var_node, size_t loss_scale,
+    proto::VarType::Type dtype) const {
   for (size_t i = 0; i < places_.size(); ++i) {
-    // Insert ScaleCost OpHandle
     auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
     auto *op_handle = new ScaleLossGradOpHandle(
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
-        nranks, local_scopes_[i], places_[i], dev_ctx, dtype);
+        loss_scale, local_scopes_[i], places_[i], dev_ctx, dtype);
     result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -669,9 +441,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
-                                                     ir::Node *node,
-                                                     size_t num_places) const {
+void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
+    ir::Graph *result, ir::Node *node, size_t num_places) const {
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
@@ -681,9 +452,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
   }
 }
 
-VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
-                                                   const std::string &og,
-                                                   int dst_dev_id) const {
+VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
+                                                       const std::string &og,
+                                                       int dst_dev_id) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
@@ -712,51 +483,273 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
   return var;
 }
 
-int MultiDevSSAGraphBuilder::CreateDistTrainOp(
-    ir::Graph *result, ir::Node *node,
-    std::unordered_map<std::string, int> *sharded_var_device) const {
-  int op_dev_id = -1;
-  std::vector<std::string> input_var_names;
-  std::vector<std::string> output_var_names;
-  for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Name());
+bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const {
+  return boost::get<int>(
+             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+             (static_cast<int>(OpRole::kBackward) |
+              static_cast<int>(OpRole::kLoss)) &&
+         !loss_var_name_.empty();  // If loss_var is empty. This is test mode
+}
+
+bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
+    const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+    return true;
   }
-  for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Name());
+  return false;
+}
+
+void AllReduceSSAGraphBuilder::InsertCollectiveOp(
+    ir::Graph *result, const std::string &p_name,
+    const std::string &g_name) const {
+  if (IsSparseGradient(g_name)) {
+    CreateReduceOp(result, g_name, 0);
+    CreateBroadcastOp(result, g_name, 0);
+  } else {
+    CreateAllReduceOp(result, g_name);
   }
+}
 
-  if (node->Op()->Type() == "split_byref" ||
-      node->Op()->Type() == "split_selected_rows" ||
-      node->Op()->Type() == "split_ids") {
-    // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device);
-    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
-      op_dev_id = GetAppropriateDeviceID(input_var_names);
-      for (auto &varname : input_var_names) {
-        sharded_var_device->emplace(varname, op_dev_id);
+int BalanceVarSSAGraphBuilder::GetVarDeviceID(
+    const std::string &varname) const {
+  auto got = sharded_var_device_.find(varname);
+  if (got == sharded_var_device_.end()) {
+    auto pos = varname.find(framework::kNewGradSuffix);
+    if (pos != std::string::npos) {
+      got = sharded_var_device_.find(varname.substr(0, pos));
+    }
+  }
+  return got == sharded_var_device_.end() ? -1 : got->second;
+}
+
+int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
+  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
+    return -1;
+  }
+  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
+    return -1;
+  }
+  auto param_grad = boost::get<std::vector<std::string>>(
+      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
+  int dev_id = GetVarDeviceID(param_grad[1]);
+  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
+                    node->Op()->Type(), param_grad[0], param_grad[1]);
+  return dev_id;
+}
+
+size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    if (all_vars_.find(var_name) == all_vars_.end()) continue;
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
+
+void BalanceVarSSAGraphBuilder::ResetState() const {
+  balance_vars_.clear();
+  sharded_var_device_.clear();
+
+  balance_vars_.resize(places_.size(), 0);
+}
+
+void ReduceSSAGraphBuilder::Init() const {
+  MultiDevSSAGraphBuilderBase::Init();
+  ResetState();
+}
+
+void ReduceSSAGraphBuilder::ResetState() const {
+  BalanceVarSSAGraphBuilder::ResetState();
+  bcast_var_name_set_.clear();
+  bcast_var_name_set_.resize(places_.size());
+}
+
+void ReduceSSAGraphBuilder::InsertCollectiveOp(
+    ir::Graph *result, const std::string &p_name,
+    const std::string &g_name) const {
+  size_t cur_device_id = GetAppropriateDeviceID({g_name});
+  CreateReduceOp(result, g_name, cur_device_id);
+  sharded_var_device_.emplace(g_name, cur_device_id);
+  bcast_var_name_set_[cur_device_id].emplace(p_name);
+}
+
+bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
+                                              ir::Node *node) const {
+  int op_dev_id = BalanceVarSSAGraphBuilder::GetOpDeviceID(node);
+  if (op_dev_id != -1) {
+    // This op only runs on one specific device.
+    CreateComputationalOp(result, node, op_dev_id);
+    for (ir::Node *n : node->outputs) {
+      sharded_var_device_.emplace(n->Name(), op_dev_id);
+    }
+    return true;
+  }
+  return false;
+}
+
+void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
+  if (UseGPU()) {
+    if (strategy_.fuse_broadcast_op_) {
+      CreateFusedBroadcastOp(result, bcast_var_name_set_);
+    } else {
+      for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
+        auto &to_bcast_set = bcast_var_name_set_[dev_id];
+        for (auto &bcast_name : to_bcast_set) {
+          CreateBroadcastOp(result, bcast_name, dev_id);
+        }
       }
     }
-    for (auto &varname : output_var_names) {
-      sharded_var_device->emplace(varname, op_dev_id);
+  }
+}
+
+int ReduceSSAGraphBuilder::GetOpDeviceID(
+    ir::Node *node,
+    std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
+  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
+    return -1;
+  }
+
+  auto param_grad = boost::get<std::vector<std::string>>(
+      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
+  int dev_id = GetVarDeviceID(param_grad[1]);
+
+  if (dev_id == -1) {
+    (*delay_ops)[param_grad[1]].push_back(node);
+    return -2;
+  }
+  return dev_id;
+}
+
+std::vector<ir::Node *> ReduceSSAGraphBuilder::SortOperations(
+    const ir::Graph &graph) const {
+  std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(graph);
+  return SortForReduceMode(sorted_ops);
+}
+
+std::vector<ir::Node *> ReduceSSAGraphBuilder::SortForReduceMode(
+    const std::vector<ir::Node *> &topo_ops) const {
+  std::vector<ir::Node *> sorted_ops;
+  std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op;
+  sorted_ops.reserve(topo_ops.size());
+  ResetState();
+
+  auto insert_delayed_op = [&](const std::string &var_name, int dev_id) {
+    sharded_var_device_.emplace(var_name, dev_id);
+    if (delayed_op.count(var_name)) {
+      auto &ops = delayed_op.at(var_name);
+      sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end());
+      delayed_op.at(var_name).clear();
     }
-  } else if (node->Op()->Type() == "concat") {
-    op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device);
-    for (auto &varname : output_var_names) {
-      sharded_var_device->emplace(varname, op_dev_id);
+  };
+
+  for (ir::Node *node : topo_ops) {
+    int op_dev_id = GetOpDeviceID(node, &delayed_op);
+    if (op_dev_id > -1) {
+      // This op only runs on one specific device.
+      sorted_ops.emplace_back(node);
+      for (ir::Node *n : node->outputs) {
+        insert_delayed_op(n->Name(), op_dev_id);
+      }
+    } else if (op_dev_id == -1) {
+      // This op runs on all devices, and its output may have parameter's
+      // gradients.
+      sorted_ops.emplace_back(node);
+      bool is_bk_op =
+          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
+                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                            static_cast<int>(OpRole::kBackward));
+      if (!is_bk_op) continue;
+      // Currently, we assume that once gradient is generated, it can be
+      // broadcast, and each gradient is only broadcast once.
+      std::vector<std::string> backward_vars;
+      try {
+        backward_vars =
+            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      } catch (boost::bad_get e) {
+      }
+      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+      for (size_t i = 0; i < backward_vars.size(); i += 2) {
+        auto &g_name = backward_vars[i + 1];
+        size_t cur_device_id = GetAppropriateDeviceID({g_name});
+        insert_delayed_op(g_name, static_cast<int>(cur_device_id));
+      }
+    } else if (op_dev_id == -2) {
+      // The Op on which the Op depends has not yet been generated.
     }
-  } else {
-    LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
-    PADDLE_THROW(
-        "the distribute training related op should be in [split_byref, "
-        "concat].");
   }
 
-  PADDLE_ENFORCE(op_dev_id != -1,
-                 "can not find right place for distributed op: %s",
-                 node->Op()->Type());
+  PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
 
-  CreateComputationalOp(result, node, op_dev_id);
-  return op_dev_id;
+  ResetState();
+  return sorted_ops;
+}
+
+void DistSSAGraphBuilder::Init() const {
+  MultiDevSSAGraphBuilderBase::Init();
+  ResetState();
+}
+
+void DistSSAGraphBuilder::ResetState() const {
+  BalanceVarSSAGraphBuilder::ResetState();
+  bcast_var_name_set_.clear();
+  bcast_var_name_set_.resize(places_.size());
+}
+
+bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
+                                            ir::Node *node) const {
+  bool insert_op = false;
+  if (OpHaveRole(*node, OpRole::kRPC)) {
+    int op_dev_id = CreateRPCOp(result, node);
+    PADDLE_ENFORCE(op_dev_id != -1,
+                   "Can not schedule the RPC operator to the right place.");
+    if (node->Op()->Type() == "recv") {
+      auto recv_vars_attr =
+          boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+              OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      PADDLE_ENFORCE(recv_vars_attr.size() == 2UL);  // [parameter, gradient]
+      if (recv_vars_attr[0].find(".block") == std::string::npos) {
+        bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]);
+      }
+    }
+    insert_op = true;
+    need_broadcast_var_ = true;
+  } else if (OpHaveRole(*node, OpRole::kDist)) {
+    int op_dev_id = CreateDistTrainOp(result, node);
+    if (node->Op()->Type() == "concat") {
+      auto origin_param_name = node->Op()->OutputArgumentNames()[0];
+      bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
+    }
+    insert_op = true;
+  } else {
+    int op_dev_id = GetOpDeviceID(node);
+    if (op_dev_id != -1) {  // This op only runs on one specific device.
+      CreateComputationalOp(result, node, op_dev_id);
+      for (ir::Node *n : node->outputs) {
+        sharded_var_device_.emplace(n->Name(), op_dev_id);
+      }
+      insert_op = true;
+    }
+  }
+  return insert_op;
 }
 
 void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
@@ -775,13 +768,11 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
-int MultiDevSSAGraphBuilder::CreateRPCOp(
-    ir::Graph *result, ir::Node *node,
-    std::unordered_map<std::string, int> *sharded_var_device) const {
+int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(node->inputs[0]->Name(), *sharded_var_device);
+    op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
     PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
@@ -799,9 +790,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
       VLOG(10) << "send grad " << input_var_names[0] << " origin "
                << send_param_grad[1] << " place: " << op_dev_id;
       for (auto &varname : input_var_names) {
-        sharded_var_device->emplace(varname, op_dev_id);
+        sharded_var_device_.emplace(varname, op_dev_id);
       }
-      sharded_var_device->emplace(send_param_grad[1], op_dev_id);
+      sharded_var_device_.emplace(send_param_grad[1], op_dev_id);
     }
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
@@ -811,7 +802,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
     auto recv_param_grad = boost::get<std::vector<std::string>>(
         node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
     if (recv_param_grad.size() == 2U) {
-      op_dev_id = GetVarDeviceID(recv_param_grad[1], *sharded_var_device);
+      op_dev_id = GetVarDeviceID(recv_param_grad[1]);
       VLOG(10) << "recv param " << recv_param_grad[0]
                << " get grad place: " << recv_param_grad[1]
                << " place: " << op_dev_id;
@@ -819,7 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
       op_dev_id = GetAppropriateDeviceID(output_var_names);
     }
     for (auto &varname : output_var_names) {
-      sharded_var_device->emplace(varname, op_dev_id);
+      sharded_var_device_.emplace(varname, op_dev_id);
     }
   } else {
     // send_barrier, fetch_barrier will run on place 0;
@@ -846,7 +837,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
     for (ir::Node *output : node->outputs) {
       int outvar_dev_id = op_dev_id;
       if (node->Op()->Type() == "fetch_barrier") {
-        outvar_dev_id = GetVarDeviceID(output->Name(), *sharded_var_device);
+        outvar_dev_id = GetVarDeviceID(output->Name());
         PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
       }
       p = places_[outvar_dev_id];
@@ -863,29 +854,124 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
   return op_dev_id;
 }
 
-bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
-  PADDLE_ENFORCE(all_vars_.count(og) != 0);
-  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
-    return true;
+int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
+                                           ir::Node *node) const {
+  int op_dev_id = -1;
+  std::vector<std::string> input_var_names;
+  std::vector<std::string> output_var_names;
+  for (ir::Node *input : node->inputs) {
+    input_var_names.push_back(input->Name());
   }
-  return false;
+  for (ir::Node *output : node->outputs) {
+    output_var_names.push_back(output->Name());
+  }
+
+  if (node->Op()->Type() == "split_byref" ||
+      node->Op()->Type() == "split_selected_rows" ||
+      node->Op()->Type() == "split_ids") {
+    // TODO(paddle-dev): getting the first var is not safe.
+    op_dev_id = GetVarDeviceID(input_var_names[0]);
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      op_dev_id = GetAppropriateDeviceID(input_var_names);
+      for (auto &varname : input_var_names) {
+        sharded_var_device_.emplace(varname, op_dev_id);
+      }
+    }
+    for (auto &varname : output_var_names) {
+      sharded_var_device_.emplace(varname, op_dev_id);
+    }
+  } else if (node->Op()->Type() == "concat") {
+    op_dev_id = GetVarDeviceID(input_var_names[0]);
+    for (auto &varname : output_var_names) {
+      sharded_var_device_.emplace(varname, op_dev_id);
+    }
+  } else {
+    LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
+    PADDLE_THROW(
+        "the distribute training related op should be in [split_byref, "
+        "concat].");
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1,
+                 "can not find right place for distributed op: %s",
+                 node->Op()->Type());
+
+  CreateComputationalOp(result, node, op_dev_id);
+  return op_dev_id;
 }
 
-bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
-  return boost::get<int>(
-             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-             (static_cast<int>(OpRole::kBackward) |
-              static_cast<int>(OpRole::kLoss)) &&
-         !loss_var_name_.empty();  // If loss_var is empty. This is test mode
+void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
+                                             const std::string &p_name,
+                                             const std::string &g_name) const {
+  size_t cur_device_id = 0;
+  switch (strategy_.reduce_) {
+    case BuildStrategy::ReduceStrategy::kReduce:
+      cur_device_id = GetAppropriateDeviceID({g_name});
+      CreateReduceOp(result, g_name, cur_device_id);
+      sharded_var_device_.emplace(g_name, cur_device_id);
+      break;
+    case BuildStrategy::ReduceStrategy::kAllReduce:
+      if (IsSparseGradient(g_name)) {
+        CreateReduceOp(result, g_name, 0);
+        CreateBroadcastOp(result, g_name, 0);
+      } else {
+        CreateAllReduceOp(result, g_name);
+      }
+      break;
+    default:
+      LOG(FATAL) << "Unknown reduce strategy.";
+      break;
+  }
+}
+
+void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
+  if (need_broadcast_var_ ||
+      (UseGPU() &&
+       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) {
+    if (strategy_.fuse_broadcast_op_) {
+      CreateFusedBroadcastOp(result, bcast_var_name_set_);
+    } else {
+      for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
+        auto &to_bcast_set = bcast_var_name_set_[dev_id];
+        for (auto &bcast_name : to_bcast_set) {
+          CreateBroadcastOp(result, bcast_name, dev_id);
+        }
+      }
+    }
+  }
+}
+
+std::unordered_set<std::string> &MultiDevSSAGraphBuilder() {
+  static std::unordered_set<std::string> regs;
+  return regs;
 }
+
+static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
+  MultiDevSSAGraphBuilder().insert(builder_mode);
+  return 0;
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(multi_devices_pass,
-              paddle::framework::details::MultiDevSSAGraphBuilder)
-    .RequirePassAttr(paddle::framework::details::kLossVarName)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes)
-    .RequirePassAttr(paddle::framework::details::kStrategy)
-    .RequirePassAttr(paddle::framework::details::kNRanks);
+#define REGISTER_MULTI_DEVICES_PASS(pass_name, pass_class)                     \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                              \
+      _reg_ssa_graph_builder_##pass_name,                                      \
+      "REGISTER_MULTI_DEVICES_PASS must be called in global namespace.");      \
+  int _reg_ssa_graph_builder_entry_##pass_name =                               \
+      paddle::framework::details::MultiDevSSAGraphBuilderRegister(#pass_name); \
+  REGISTER_PASS(pass_name, pass_class)                                         \
+      .RequirePassAttr(paddle::framework::details::kLossVarName)               \
+      .RequirePassAttr(paddle::framework::details::kPlaces)                    \
+      .RequirePassAttr(paddle::framework::details::kLocalScopes)               \
+      .RequirePassAttr(paddle::framework::details::kStrategy)                  \
+      .RequirePassAttr(paddle::framework::details::kNRanks)
+
+REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
+                            paddle::framework::details::ReduceSSAGraphBuilder);
+REGISTER_MULTI_DEVICES_PASS(
+    allreduce_mode_multi_devices_pass,
+    paddle::framework::details::AllReduceSSAGraphBuilder);
+REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
+                            paddle::framework::details::DistSSAGraphBuilder);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 7029e9dc18..6d4386538e 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+
 #include <string>
 #include <utility>
 #include <vector>
@@ -30,78 +31,70 @@ namespace framework {
 class Scope;
 namespace details {
 
-class MultiDevSSAGraphBuilder : public ir::Pass {
+constexpr char kLossVarName[] = "loss_var_name";
+constexpr char kPlaces[] = "places";
+constexpr char kLocalScopes[] = "local_scopes";
+constexpr char kStrategy[] = "strategy";
+constexpr char kNRanks[] = "nranks";
+
+class MultiDevSSAGraphBuilderBase : public ir::Pass {
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override;
 
- private:
-  void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
-                         size_t device_id) const;
-  void Init() const;
+  virtual void Init() const;
 
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  mutable platform::NCCLContextMap *nccl_ctxs_;
-#endif
+  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
 
-  int GetVarDeviceID(
-      const std::string &varname,
-      const std::unordered_map<std::string, int> &sharded_var_device) const;
+  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                                  const std::string &g_name) const = 0;
 
-  bool IsScaleLossOp(ir::Node *node) const;
+  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
+
+  virtual void InsertPostprocessOps(ir::Graph *result) const = 0;
 
-  int CreateRPCOp(
-      ir::Graph *result, ir::Node *node,
-      std::unordered_map<std::string, int> *sharded_var_device) const;
-  int CreateDistTrainOp(
-      ir::Graph *result, ir::Node *node,
-      std::unordered_map<std::string, int> *sharded_var_device) const;
+  bool UseGPU() const;
+
+  bool NeedCollectiveOps() const;
+
+  bool IsScaleLossOp(ir::Node *node) const;
 
   void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                               size_t num_places) const;
 
   void CreateScaleLossGradOp(ir::Graph *result,
                              const std::string &loss_grad_name,
-                             ir::Node *out_var_node,
+                             ir::Node *out_var_node, size_t loss_scale,
                              proto::VarType::Type dtype) const;
 
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                             int dst_dev_id) const;
+
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
                              int dev_id) const;
 
-  int GetOpDeviceID(
-      ir::Node *node,
-      const std::unordered_map<std::string, int> &sharded_var_device) const;
-
-  void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
+  bool IsSparseGradient(const std::string &og) const;
 
-  void InsertDataBalanceOp(ir::Graph *result,
-                           const std::vector<std::string> &datas) const;
+  void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
 
   void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
+  void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;
+
   void CreateFusedBroadcastOp(
       ir::Graph *result,
       const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
 
-  bool IsSparseGradient(const std::string &og) const;
-
-  size_t GetAppropriateDeviceID(
-      const std::vector<std::string> &var_names) const;
-
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
 
-  std::vector<ir::Node *> SortForReduceMode(
-      const std::vector<ir::Node *> &) const;
+  void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
+                         size_t device_id) const;
 
-  int GetOpDeviceID(
-      ir::Node *node,
-      const std::unordered_map<std::string, int> &shared_var_device,
-      std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops)
-      const;
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  mutable platform::NCCLContextMap *nccl_ctxs_;
+#endif
 
   mutable std::string loss_var_name_;
   mutable std::vector<platform::Place> places_;
@@ -109,8 +102,83 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
 
   mutable BuildStrategy strategy_;
   mutable std::unordered_map<std::string, VarDesc *> all_vars_;
+};
+
+class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
+ protected:
+  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                                  const std::string &g_name) const;
+
+  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
+    return false;
+  }
+
+  virtual void InsertPostprocessOps(ir::Graph *result) const {}
+};
+
+class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
+ protected:
+  int GetVarDeviceID(const std::string &varname) const;
+
+  int GetOpDeviceID(ir::Node *node) const;
+
+  size_t GetAppropriateDeviceID(
+      const std::vector<std::string> &var_names) const;
+
+  virtual void ResetState() const;
+
+  mutable std::unordered_map<std::string, int> sharded_var_device_;
   mutable std::vector<int64_t> balance_vars_;
 };
+
+class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
+ protected:
+  virtual void Init() const;
+
+  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                                  const std::string &g_name) const;
+
+  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
+
+  virtual void InsertPostprocessOps(ir::Graph *result) const;
+
+  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
+
+  virtual void ResetState() const;
+
+  int GetOpDeviceID(ir::Node *node,
+                    std::unordered_map<std::string, std::vector<ir::Node *>>
+                        *delay_ops) const;
+
+  std::vector<ir::Node *> SortForReduceMode(
+      const std::vector<ir::Node *> &topo_ops) const;
+
+  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
+};
+
+class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
+ protected:
+  virtual void Init() const;
+
+  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
+
+  virtual void InsertPostprocessOps(ir::Graph *result) const;
+
+  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                                  const std::string &g_name) const;
+
+  virtual void ResetState() const;
+
+  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
+
+  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
+
+  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
+  mutable bool need_broadcast_var_{false};
+};
+
+std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 3b81d59ad9..dce755c91a 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -946,13 +946,6 @@ All parameter, weight, gradient are variables in Paddle.
           R"DOC(The type is STR, debug_graphviz_path indicate the path that
                     writing the SSA Graph to file in the form of graphviz, you.
                     It is useful for debugging. Default "")DOC")
-      .def_property(
-          "enable_data_balance",
-          [](const BuildStrategy &self) { return self.enable_data_balance_; },
-          [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
-            self.enable_data_balance_ = b;
-          })  // FIXME(chengudo): enable_data_balance seems not important
       .def_property(
           "enable_sequential_execution",
           [](const BuildStrategy &self) {
@@ -1007,6 +1000,10 @@ All parameter, weight, gradient are variables in Paddle.
           "memory_optimize",
           [](const BuildStrategy &self) { return self.memory_optimize_; },
           [](BuildStrategy &self, bool b) { self.memory_optimize_ = b; })
+      .def_property(
+          "is_distribution",
+          [](const BuildStrategy &self) { return self.is_distribution_; },
+          [](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
       .def_property(
           "memory_early_delete",
           [](const BuildStrategy &self) { return self.memory_early_delete_; },
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index c97a93ec36..3b066eda11 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -29,6 +29,15 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
 BuildStrategy = core.ParallelExecutor.BuildStrategy
 
 
+def _is_pserver_mode(main_program):
+    main = main_program if main_program \
+        else framework.default_main_program()
+    for op in main.global_block().ops:
+        if op.type in ["send", "recv"]:
+            return True
+    return False
+
+
 class ParallelExecutor(object):
     """
     ParallelExecutor is designed for data parallelism, which focuses on distributing
@@ -128,6 +137,11 @@ class ParallelExecutor(object):
             build_strategy = BuildStrategy()
         build_strategy.num_trainers = num_trainers
         build_strategy.trainer_id = trainer_id
+        # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
+        # num_trainers is 1, so the current fields of build_strategy doesn't tell if
+        # it's distributed model.
+        build_strategy.is_distribution = _is_pserver_mode(
+            main_program) or num_trainers > 1
 
         # step4: get main_program, scope, local_scopes
         main = main_program if main_program \
diff --git a/python/paddle/fluid/tests/unittests/test_reader_reset.py b/python/paddle/fluid/tests/unittests/test_reader_reset.py
index e97a05b6f9..7eeffa1039 100644
--- a/python/paddle/fluid/tests/unittests/test_reader_reset.py
+++ b/python/paddle/fluid/tests/unittests/test_reader_reset.py
@@ -75,8 +75,6 @@ class TestReaderReset(unittest.TestCase):
         exe.run(startup_prog)
 
         build_strategy = fluid.BuildStrategy()
-        if with_double_buffer:
-            build_strategy.enable_data_balance = True
         exec_strategy = fluid.ExecutionStrategy()
         parallel_exe = fluid.ParallelExecutor(
             use_cuda=self.use_cuda,