From fe7ed285d131ba99e82538e76cb7ac5381e97809 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Wed, 21 Mar 2018 14:49:02 +0800
Subject: [PATCH] Extract NCCLCtxMap

---
 paddle/fluid/framework/CMakeLists.txt         |   2 +-
 paddle/fluid/framework/details/CMakeLists.txt |   1 +
 .../fluid/framework/details/op_handle_base.cc |  84 +++++++++++++
 .../fluid/framework/details/op_handle_base.h  |  48 ++++++++
 paddle/fluid/framework/details/var_handle.h   |   4 +-
 paddle/fluid/framework/parallel_executor.cc   | 114 +++---------------
 paddle/fluid/platform/nccl_helper.h           |  46 +++++++
 7 files changed, 196 insertions(+), 103 deletions(-)
 create mode 100644 paddle/fluid/framework/details/op_handle_base.cc
 create mode 100644 paddle/fluid/framework/details/op_handle_base.h

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 9d2dc29028..afc7ec9d66 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
 framework_proto backward glog lod_rank_table feed_fetch_method)
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-        framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool var_handle)
+        framework_proto backward glog lod_rank_table simple_threadpool var_handle op_handle_base)
 
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 5074715e2e..d9bdf0b94d 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -1 +1,2 @@
 cc_library(var_handle SRCS var_handle.cc DEPS place)
+cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
new file mode 100644
index 0000000000..094b62cc94
--- /dev/null
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -0,0 +1,84 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+std::string OpHandleBase::DebugString() const {
+  std::stringstream ss;
+  ss << "(";
+  for (auto *var : inputs_) {
+    ss << var->DebugString() << ", ";
+  }
+  ss << ") --> (";
+  for (auto *var : outputs_) {
+    ss << var->DebugString() << ", ";
+  }
+  ss << ")\n";
+  return ss.str();
+}
+
+OpHandleBase::~OpHandleBase() {}
+
+void OpHandleBase::Run(bool use_event) {
+#ifdef PADDLE_WITH_CUDA
+  if (events_.empty() && use_event) {
+    for (auto &p : dev_ctx_) {
+      int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
+      cudaSetDevice(dev_id);
+      cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
+    }
+  }
+#else
+  PADDLE_ENFORCE(!use_event);
+#endif
+
+  RunImpl();
+
+#ifdef PADDLE_WITH_CUDA
+  if (use_event) {
+    for (auto &p : dev_ctx_) {
+      int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
+      auto stream =
+          static_cast<platform::CUDADeviceContext *>(p.second)->stream();
+      cudaEventRecord(events_.at(dev_id), stream);
+    }
+  }
+#endif
+}
+
+void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
+    for (auto &dev_ctx : dev_ctx_) {
+      dev_ctx.second->Wait();
+    }
+  } else {
+    auto stream =
+        static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+    for (auto &ev : events_) {
+      PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
+    }
+  }
+#else
+  for (auto &dev_ctx : dev_ctx_) {
+    dev_ctx.second->Wait();
+  }
+#endif
+}
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
new file mode 100644
index 0000000000..bdfd1f78ad
--- /dev/null
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -0,0 +1,48 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/platform/device_context.h"
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct OpHandleBase {
+  std::vector<VarHandleBase *> inputs_;
+  std::vector<VarHandleBase *> outputs_;
+  std::unordered_map<platform::Place, platform::DeviceContext *,
+                     platform::PlaceHash>
+      dev_ctx_;
+
+#ifdef PADDLE_WITH_CUDA
+  std::unordered_map<int, cudaEvent_t> events_;
+#endif
+
+  std::string DebugString() const;
+
+  virtual ~OpHandleBase();
+
+  void Run(bool use_event);
+
+  virtual void Wait(platform::DeviceContext *waited_dev);
+
+ protected:
+  virtual void RunImpl() = 0;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index 613ff901b1..893cc15f6c 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -21,10 +21,8 @@
 
 namespace paddle {
 namespace framework {
-
-struct OpHandleBase;
-
 namespace details {
+struct OpHandleBase;
 
 // VarHandleBase is the var node in the dependency graph.
 // A variable can only be generated by a single operator. i.e.
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 2b094eba1e..3c24fa4bdf 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -14,86 +14,22 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/parallel_executor.h"
 #include "ThreadPool.h"
-#include "executor.h"
 #include "lod_tensor.h"
 #include "lod_tensor_array.h"
 #include "op_registry.h"
+#include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
-#include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 
 namespace paddle {
 namespace framework {
 
 using details::DummyVarHandle;
+using details::OpHandleBase;
 using details::VarHandle;
 using details::VarHandleBase;
 
-struct OpHandleBase {
-  std::vector<VarHandleBase *> inputs_;
-  std::vector<VarHandleBase *> outputs_;
-  std::unordered_map<platform::Place, platform::DeviceContext *,
-                     platform::PlaceHash>
-      dev_ctx_;
-
-  std::unordered_map<int, cudaEvent_t> events_;
-
-  std::string DebugString() {
-    std::stringstream ss;
-    ss << "(";
-    for (auto *var : inputs_) {
-      ss << var->DebugString() << ", ";
-    }
-    ss << ") --> (";
-    for (auto *var : outputs_) {
-      ss << var->DebugString() << ", ";
-    }
-    ss << ")\n";
-    return ss.str();
-  }
-
-  virtual ~OpHandleBase() {}
-
-  void Run(bool use_event) {
-    if (events_.empty() && use_event) {
-      for (auto &p : dev_ctx_) {
-        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
-        cudaSetDevice(dev_id);
-        cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
-      }
-    }
-
-    RunImpl();
-
-    if (use_event) {
-      for (auto &p : dev_ctx_) {
-        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
-        auto stream =
-            static_cast<platform::CUDADeviceContext *>(p.second)->stream();
-        cudaEventRecord(events_.at(dev_id), stream);
-      }
-    }
-  }
-
-  virtual void Wait(platform::DeviceContext *waited_dev) {
-    if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
-      for (auto &dev_ctx : dev_ctx_) {
-        dev_ctx.second->Wait();
-      }
-    } else {
-      auto stream =
-          static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
-      for (auto &ev : events_) {
-        PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
-      }
-    }
-  }
-
- protected:
-  virtual void RunImpl() = 0;
-};
-
 struct ScaleLossGradOpHandle : public OpHandleBase {
   float coeff_;
   Scope *scope_;
@@ -193,12 +129,7 @@ class ParallelExecutorPrivate {
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;
 
-  std::unordered_map<int, platform::NCCLContext> communication_streams_;
-
-  platform::NCCLContext &GetNCCLCtx(platform::Place p) {
-    int dev_id = boost::get<platform::CUDAPlace>(p).device;
-    return communication_streams_.at(dev_id);
-  }
+  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 
   platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
     if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
@@ -206,7 +137,7 @@ class ParallelExecutorPrivate {
           platform::DeviceContextPool::Instance().Get(place));
     } else {
 #ifdef PADDLE_WITH_CUDA
-      return GetNCCLCtx(place).ctx_.get();
+      return nccl_ctxs_->DevCtx(place);
 #else
       PADDLE_THROW("Not compiled with CUDA")
 #endif
@@ -293,15 +224,12 @@ class ParallelExecutorPrivate {
 struct NCCLAllReduceOpHandle : public OpHandleBase {
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
-  const std::unordered_map<int, platform::NCCLContext> &communication_ctxs_;
+  const platform::NCCLContextMap &nccl_ctxs_;
 
-  explicit NCCLAllReduceOpHandle(
-      const std::vector<Scope *> &local_scopes,
-      const std::vector<platform::Place> &places,
-      const std::unordered_map<int, platform::NCCLContext> &ctxs)
-      : local_scopes_(local_scopes),
-        places_(places),
-        communication_ctxs_(ctxs) {}
+  explicit NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                 const std::vector<platform::Place> &places,
+                                 const platform::NCCLContextMap &ctxs)
+      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {}
 
   void Wait(platform::DeviceContext *waited_dev) override {
     OpHandleBase::Wait(waited_dev);
@@ -343,7 +271,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
         if (numel == 0) {
           numel = static_cast<size_t>(lod_tensor.numel());
         }
-        auto &nccl_ctx = communication_ctxs_.at(dev_id);
+        auto &nccl_ctx = nccl_ctxs_.at(dev_id);
         PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm_, nccl_ctx.stream()));
@@ -491,8 +419,7 @@ void ParallelExecutor::ConstructDependencyGraph(
         if (grads.count(og) != 0) {  // is param grad
           // Insert NCCL AllReduce Op
           member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
-              member_->local_scopes_, member_->places_,
-              member_->communication_streams_));
+              member_->local_scopes_, member_->places_, *member_->nccl_ctxs_));
           auto *op_handle = member_->ops_.back().get();
 
           for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -598,15 +525,12 @@ void ParallelExecutor::BCastParamsToGPUs(
           buffer = t->mutable_data(place, main_tensor.type());
         }
 
-        auto &nccl_ctx = member_->GetNCCLCtx(place);
+        auto &nccl_ctx = member_->nccl_ctxs_->at(place);
         platform::dynload::ncclBcast(buffer, numel, data_type, 0,
                                      nccl_ctx.comm_, nccl_ctx.stream());
       }
     }
-
-    for (auto &stream : member_->communication_streams_) {
-      stream.second.ctx_->Wait();
-    }
+    member_->nccl_ctxs_->WaitAll();
   }
 #else
   PADDLE_THROW("Not compiled with CUDA");
@@ -615,15 +539,7 @@ void ParallelExecutor::BCastParamsToGPUs(
 
 void ParallelExecutor::BuildNCCLCommunicator() const {
 #ifdef PADDLE_WITH_CUDA
-  for (auto &place : member_->places_) {
-    int dev_id = boost::get<platform::CUDAPlace>(place).device;
-
-    member_->communication_streams_.emplace(dev_id,
-                                            platform::NCCLContext(dev_id));
-  }
-
-  platform::NCCLContext::InitNCCLContext(member_->communication_streams_,
-                                         member_->places_);
+  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
 #endif
 }
 
@@ -682,7 +598,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     op->offset_ = i;
     op->local_scopes_ = &member_->local_scopes_;
     for (auto &p : member_->places_) {
-      op->dev_ctx_[p] = member_->GetNCCLCtx(p).ctx_.get();
+      op->dev_ctx_[p] = member_->nccl_ctxs_->DevCtx(p);
     }
 
     for (auto *var : vars) {
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index 3db846b024..2999004320 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -87,5 +87,51 @@ struct NCCLContext {
   }
 };
 
+struct NCCLContextMap {
+  std::unordered_map<int, NCCLContext> contexts_;
+  std::vector<int> order_;
+
+  NCCLContextMap(const std::vector<platform::Place> &places) {
+    order_.reserve(places.size());
+    for (auto &p : places) {
+      int dev_id = boost::get<CUDAPlace>(p).device;
+      order_.emplace_back(dev_id);
+      contexts_.emplace(dev_id, NCCLContext(dev_id));
+    }
+    PADDLE_ENFORCE_EQ(
+        order_.size(), contexts_.size(),
+        "NCCLContextMap must not contain two or more contexts for the same device");
+
+    std::vector<ncclComm_t> comms;
+    comms.resize(order_.size());
+
+    PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+        &comms[0], static_cast<int>(order_.size()), &order_[0]));
+
+    int i = 0;
+    for (auto &dev_id : order_) {
+      contexts_.at(dev_id).comm_ = comms[i++];
+    }
+  }
+
+  CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
+
+  CUDADeviceContext *DevCtx(platform::Place p) const {
+    return DevCtx(boost::get<CUDAPlace>(p).device);
+  }
+
+  const NCCLContext &at(platform::Place p) const {
+    return this->at(boost::get<CUDAPlace>(p).device);
+  }
+
+  const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
+
+  void WaitAll() {
+    for (auto &p : contexts_) {
+      p.second.ctx_->Wait();
+    }
+  }
+};
+
 }  // namespace platform
 }  // namespace paddle