From 093d227a7796e50dc2f7a04094b4725c6f40f399 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Mon, 16 Apr 2018 10:33:01 +0800
Subject: [PATCH] Use mutex to stabilize ncclCtxMap

---
 paddle/fluid/platform/nccl_helper.h | 50 +++++++++--------------------
 1 file changed, 16 insertions(+), 34 deletions(-)

diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index ca9ab2c7ae..0013597fd5 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
 
 class NCCLGroupGuard {
  public:
+  static std::mutex &NCCLMutex() {  // one mutex shared by every guard
+    static std::mutex mtx;  // function-local static: single instance
+    return mtx;
+  }
+
   inline NCCLGroupGuard() {
-    mutex().lock();
+    NCCLMutex().lock();  // held for the guard's lifetime
     PADDLE_ENFORCE(dynload::ncclGroupStart());
   }
 
   inline ~NCCLGroupGuard() {
     PADDLE_ENFORCE(dynload::ncclGroupEnd());
-    mutex().unlock();
-  }
-
- private:
-  static std::mutex &mutex() {
-    static std::mutex mtx;
-    return mtx;
+    NCCLMutex().unlock();  // after ncclGroupEnd; pairs with ctor lock()
   }
 };
 
@@ -68,26 +67,6 @@ struct NCCLContext {
   int device_id() const {
     return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
   }
-
-  static void InitNCCLContext(std::unordered_map<int, NCCLContext> *contexts,
-                              const std::vector<platform::Place> &places) {
-    std::vector<ncclComm_t> comms;
-    std::vector<int> devs;
-    comms.resize(contexts->size());
-    devs.reserve(contexts->size());
-
-    for (auto &p : places) {
-      devs.push_back(boost::get<platform::CUDAPlace>(p).device);
-    }
-
-    PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
-        &comms[0], static_cast<int>(contexts->size()), &devs[0]));
-
-    int i = 0;
-    for (auto &dev_id : devs) {
-      contexts->at(dev_id).comm_ = comms[i++];
-    }
-  }
 };
 
 struct NCCLContextMap {
@@ -107,12 +86,12 @@ struct NCCLContextMap {
         "NCCL Context Map does not support contain two or more same device");
 
     if (places.size() > 1) {
-      std::vector<ncclComm_t> comms;
-      comms.resize(order_.size());
-
-      PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
-          &comms[0], static_cast<int>(order_.size()), &order_[0]));
-
+      std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
+      {
+        std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
+        PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+            comms.get(), static_cast<int>(order_.size()), order_.data()));
+      }
       int i = 0;
       for (auto &dev_id : order_) {
         contexts_.at(dev_id).comm_ = comms[i++];
@@ -120,6 +99,9 @@ struct NCCLContextMap {
     }
   }
 
+  NCCLContextMap(const NCCLContextMap &other) = delete;
+  NCCLContextMap &operator=(const NCCLContextMap &other) = delete;
+
   CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
 
   CUDADeviceContext *DevCtx(platform::Place p) const {