From c99fca5f90ede6297eaf5b4c9617ad21df8b445d Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Thu, 21 Jun 2018 12:25:31 +0800
Subject: [PATCH] Add No Mutex

---
 .../framework/details/broadcast_op_handle.cc  | 57 ++++++++++++++-----
 .../fluid/framework/details/op_handle_base.cc | 23 ++++++++
 .../fluid/framework/details/op_handle_base.h  |  4 ++
 .../framework/details/reduce_op_handle.cc     |  4 +-
 paddle/fluid/platform/device_context.h        |  8 +++
 5 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 1d9f1bd6e4..b0bf641d9d 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -103,23 +103,50 @@ void BroadcastOpHandle::RunImpl() {
           });
     }
 
-    this->RunAndRecordEvent([&] {
-      {
-        platform::NCCLGroupGuard guard;
-        for (auto &call : broadcast_calls) {
-          call();
+    // FIXME(zcd): a temporary fix for some language model that has sparse
+    // parameter.
+    bool use_mutex = true;
+    if (in_var->IsType<paddle::framework::SelectedRows>()) {
+      use_mutex = false;
+    }
+    if (use_mutex) {
+      this->RunAndRecordEvent([&] {
+        {
+          platform::NCCLGroupGuard guard;
+          for (auto &call : broadcast_calls) {
+            call();
+          }
         }
-      }
 
-      if (!out_handle->IsTheSameVar(*in_var_handle)) {
-        auto out_var = var_scopes.at(in_var_handle->scope_idx_)
-                           ->FindVar(out_var_handles[0]->name_);
-        paddle::framework::TensorCopy(
-            in_tensor, in_var_handle->place_,
-            *(dev_ctxes_.at(in_var_handle->place_)),
-            &VariableVisitor::GetMutableTensor(out_var));
-      }
-    });
+        if (!out_handle->IsTheSameVar(*in_var_handle)) {
+          auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                             ->FindVar(out_var_handles[0]->name_);
+          paddle::framework::TensorCopy(
+              in_tensor, in_var_handle->place_,
+              *(dev_ctxes_.at(in_var_handle->place_)),
+              &VariableVisitor::GetMutableTensor(out_var));
+        }
+      });
+    } else {
+      this->RunAndRecordEventNoMutex([&] {
+        {
+          platform::NCCLGroupGuard guard;
+          for (auto &call : broadcast_calls) {
+            call();
+          }
+        }
+
+        if (!out_handle->IsTheSameVar(*in_var_handle)) {
+          auto out_var = var_scopes.at(in_var_handle->scope_idx_)
+                             ->FindVar(out_var_handles[0]->name_);
+          paddle::framework::TensorCopy(
+              in_tensor, in_var_handle->place_,
+              *(dev_ctxes_.at(in_var_handle->place_)),
+              &VariableVisitor::GetMutableTensor(out_var));
+        }
+      });
+    }
+
 #else
     PADDLE_THROW("CUDA is not enabled.");
 #endif
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
index f79565fe71..a40a881508 100644
--- a/paddle/fluid/framework/details/op_handle_base.cc
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #endif
 }
 
+void OpHandleBase::RunAndRecordEventNoMutex(
+    const std::function<void()> &callback) {
+#ifdef PADDLE_WITH_CUDA
+  if (!events_.empty()) {  // Use event
+    std::function<void()> method = callback;
+
+    for (auto &p : dev_ctxes_) {
+      method = [method, p, this]() {
+        static_cast<platform::CUDADeviceContext *>(p.second)
+            ->RecordEventNoMutex(
+                events_.at(boost::get<platform::CUDAPlace>(p.first).device),
+                method);
+      };
+    }
+    method();
+  } else {
+#endif
+    callback();
+#ifdef PADDLE_WITH_CUDA
+  }
+#endif
+}
+
 void OpHandleBase::RunAndRecordEvent(platform::Place p,
                                      const std::function<void()> &callback) {
 #ifdef PADDLE_WITH_CUDA
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index fbd90a3296..775be0233a 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -85,6 +85,10 @@ class OpHandleBase {
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
 
+  // FIXME(zcd): A temporary fix for some language model that has sparse
+  // parameter.
+  void RunAndRecordEventNoMutex(const std::function<void()> &callback);
+
   void RunAndRecordEvent(platform::Place p,
                          const std::function<void()> &callback);
 
diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc
index 7160e346da..9a626c890f 100644
--- a/paddle/fluid/framework/details/reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/reduce_op_handle.cc
@@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() {
   }
 
   if (pre_in_var->IsType<framework::SelectedRows>()) {
-    this->RunAndRecordEvent([&] {
+    // FIXME(zcd): A temporary fix for some language model that has sparse
+    // parameter.
+    this->RunAndRecordEventNoMutex([&] {
       std::vector<const SelectedRows *> in_selected_rows =
           GetInputValues<SelectedRows>(in_var_handles, var_scopes);
       GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 292ffef1ae..d37e5ee578 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext {
     PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
   }
 
+  // FIXME(zcd): A temporary fix for some language model that has sparse
+  // parameter.
+  template <typename Callback>
+  void RecordEventNoMutex(cudaEvent_t ev, Callback callback) {
+    callback();
+    PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
+  }
+
  private:
   CUDAPlace place_;