From ff599b921885af7645858cc9e45a661e6807b864 Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Fri, 4 May 2018 23:29:59 +0800
Subject: [PATCH] use Reduce and Broadcast

---
 .../details/multi_devices_graph_builder.cc    | 62 +++----------------
 .../details/multi_devices_graph_builder.h     | 10 +--
 2 files changed, 13 insertions(+), 59 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 37100b529d..21197d587b 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   for (auto *var : program.Block(0).AllVars()) {
     var_types[var->Name()] = var->GetType();
   }
+
   auto graph = new SSAGraph();
   SSAGraph &result = *graph;
   std::unordered_set<std::string> og_has_been_broadcast;
@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
 
-  size_t cur_dev_id = 0;
-  std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
-  std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
-
-  sparse_var_name_on_devices.resize(places_.size());
-  bcast_sparse_var_name_set.resize(places_.size());
-
   // Find "send" op first for split is in front of send.
   OpDesc *send_op = GetSendOpDesc(program);
 
@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(sparse_var_name_on_devices, *op);
-      if (op_dev_id == -1) {  // var on all device
-        CreateComputationalOps(&result, *op, places_.size());
-      } else {
-        CreateComputationalOp(&result, *op, op_dev_id);
-        for (auto &var_name : op->OutputArgumentNames()) {
-          sparse_var_name_on_devices[op_dev_id].emplace(var_name);
-        }
-      }
-
+      CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding && places_.size() > 1) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once.
         for (auto &og : op->OutputArgumentNames()) {
           if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
             if (IsSparseGradient(var_types, og)) {
-              CreateReduceOp(&result, cur_dev_id, og);
-              sparse_var_name_on_devices[cur_dev_id].emplace(og);
-              bcast_sparse_var_name_set[cur_dev_id].emplace(
-                  og.substr(0, og.size() - strlen(kGradVarSuffix)));
-              cur_dev_id = (cur_dev_id + 1) % places_.size();
+              CreateReduceOp(&result, og, 0);
+              CreateBroadcastOp(&result, og, 0);
             } else {
               InsertNCCLAllReduceOp(&result, og);
             }
@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     }
   }
 
-  // Insert BCast Ops
-  for (size_t dev_id = 0; dev_id < bcast_sparse_var_name_set.size(); ++dev_id) {
-    auto &to_bcast_set = bcast_sparse_var_name_set[dev_id];
-    for (auto &bcast_name : to_bcast_set) {
-      CreateBroadcastOp(&result, bcast_name, dev_id);
-    }
-  }
-
   /*
     Dependency graph has been constructed. However, there are still data
     harzaeds need to be handled.
@@ -213,26 +187,9 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
   return false;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const std::vector<std::unordered_set<std::string>>
-        &sparse_var_name_on_devices,
-    const OpDesc &op) const {
-  int var_dev_id = -1;
-  for (auto &var_name : op.InputArgumentNames()) {
-    if (var_dev_id != -1) break;
-    for (size_t i = 0; i < sparse_var_name_on_devices.size(); ++i) {
-      if (sparse_var_name_on_devices[i].count(var_name)) {
-        var_dev_id = static_cast<int>(i);
-        break;
-      }
-    }
-  }
-  return var_dev_id;
-}
-
 void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
                                                 const std::string &p_name,
-                                                size_t dev_id) const {
+                                                size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
   auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
 #else
@@ -240,11 +197,11 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
 #endif
 
   result->ops_.emplace_back(op_handle);
-  auto *in = result->vars_.at(dev_id).at(p_name).back().get();
+  auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
-    auto &vars = result->vars_.at(dev_id).at(p_name);
+    auto &vars = result->vars_.at(i).at(p_name);
     auto &p = places_[i];
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
   }
 }
 
-VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(
-    SSAGraph *result, int dst_dev_id, const std::string &og) const {
+VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
+                                                   const std::string &og,
+                                                   int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
   result->ops_.emplace_back(
       new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 1672958b22..674e2779a1 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -75,8 +75,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                               size_t num_places) const;
 
   void CreateScaleLossGradOp(SSAGraph *result) const;
-  VarHandle *CreateReduceOp(SSAGraph *result, int dst_dev_id,
-                            const std::string &og) const;
+  VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
+                            int dst_dev_id) const;
   void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
                              int dev_id) const;
 
@@ -87,11 +87,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
 
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
-                         size_t dev_id) const;
-
-  int GetOpDeviceID(
-      const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-      const OpDesc &op) const;
+                         size_t src_dev_id) const;
 
   /**
    * Get send op in the global block of program.