From 881e063ee292eb13594147d65c4f39f3cade38fb Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Sat, 5 May 2018 14:53:52 +0800
Subject: [PATCH] follow comments

---
 .../framework/details/broadcast_op_handle.cc  | 35 +++++++++----------
 .../framework/details/gather_op_handle.cc     | 27 +++++++-------
 .../framework/details/reduce_op_handle.cc     | 24 +++++++------
 .../framework/details/ssa_graph_builder.h     |  4 ---
 paddle/fluid/framework/details/var_handle.h   | 11 +-----
 .../framework/details/variable_visitor.cc     |  4 +--
 6 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 327409914e..2afa47c81b 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -53,42 +53,39 @@ void BroadcastOpHandle::RunImpl() {
 
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  // NOTE(zcd): the Place of input can get from in_tensor and in_var_handle ,
-  // maybe they are different, because the Place that getting from in_tensor is
-  // determined at runtime, the other is determined at building SSA graph stage.
-  // If they are different, DataTransform should be applied. Currently, it has
-  // not been done yet.
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
   for (auto *out_var_handle : out_var_handles) {
-    if (*out_var_handle == *in_var_handle) {
+    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
       continue;
     }
-    auto &out_p = out_var_handle->place_;
+    auto t_out_p = out_var_handle->place_;
     auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
                         ->FindVar(out_var_handle->name_);
     PADDLE_ENFORCE_NOT_NULL(out_var);
-    PADDLE_ENFORCE_EQ(
-        out_p.which(), in_tensor.place().which(),
-        "Currently, Places of input and output must be all on CPU "
-        "or all on GPU.");
+    if (platform::is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
                                                             in_tensor.type());
   }
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
-      if (*out_var_handle == *in_var_handle) {
+      if (out_var_handle->IsTheSameVar(*in_var_handle)) {
         continue;
       }
-
       auto &out_p = out_var_handle->place_;
-      auto dev_ctx = dev_ctxes_.at(out_p);
       auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
                           ->FindVar(out_var_handle->name_);
 
-      RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
+      RunAndRecordEvent(out_p, [in_tensor, out_var] {
         paddle::framework::TensorCopy(
-            in_tensor, out_p, *dev_ctx,
+            in_tensor, platform::CPUPlace(),
             &VariableVisitor::GetMutableTensor(out_var));
       });
     }
@@ -134,8 +131,8 @@ void BroadcastOpHandle::RunImpl() {
           call();
         }
       }
-      // TODO(zcd): Maybe the unequal operator is not appropriate here.
-      if (*out_handle != *in_var_handle) {
+
+      if (!out_handle->IsTheSameVar(*in_var_handle)) {
         auto out_var = var_scopes.at(in_var_handle->scope_idx_)
                            ->FindVar(out_var_handles[0]->name_);
         paddle::framework::TensorCopy(
diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc
index 021703f1e9..3dfc972a44 100644
--- a/paddle/fluid/framework/details/gather_op_handle.cc
+++ b/paddle/fluid/framework/details/gather_op_handle.cc
@@ -75,14 +75,15 @@ void GatherOpHandle::RunImpl() {
     in_tensors.emplace_back(in_sr_value.value());
   }
 
-  // TODO(zcd): The Place of var_handle is determined at building SSA graph
-  // stage, while the Place of var is determined at runtime. If they are
-  // different, DataTransform should be applied. Currently, it has not been done
-  // yet.
-  auto &out_place = out_var_handle->place_;
-  PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(),
-                    "Currently, Places of input and output must be all on CPU "
-                    "or all on GPU.");
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
+  platform::Place t_out_p = out_var_handle->place_;
+  if (platform::is_gpu_place(pre_in_value.place())) {
+    PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                   "Places of input and output must be all on GPU.");
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
+
   auto out_var =
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);
@@ -93,18 +94,18 @@ void GatherOpHandle::RunImpl() {
   DDim out_dim = pre_in_value.GetCompleteDims();
   out_dim[0] = static_cast<int64_t>(rows);
   out_value->mutable_value()->Resize(out_dim).mutable_data(
-      out_place, pre_in_value.value().type());
+      t_out_p, pre_in_value.value().type());
   Tensor *out_tensor = out_value->mutable_value();
 
   // copy
-  auto dev_ctx = dev_ctxes_[out_place];
-  RunAndRecordEvent(out_place, [in_tensors, out_tensor, &dev_ctx, out_place] {
+  auto dev_ctx = dev_ctxes_[out_var_handle->place_];
+  RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
+                                             t_out_p] {
     int s = 0, e = 0;
     for (size_t j = 0; j < in_tensors.size(); ++j) {
       e += in_tensors[j].dims()[0];
       auto sub_out = out_tensor->Slice(s, e);
-      paddle::framework::TensorCopy(in_tensors[j], out_place, *dev_ctx,
-                                    &sub_out);
+      paddle::framework::TensorCopy(in_tensors[j], t_out_p, *dev_ctx, &sub_out);
       s = e;
     }
   });
diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc
index 5ee7008b5b..1bb04c1dfc 100644
--- a/paddle/fluid/framework/details/reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/reduce_op_handle.cc
@@ -53,6 +53,7 @@ void ReduceOpHandle::RunImpl() {
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated(in_var_handles);
 
+  // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
   std::vector<platform::Place> in_places;  // used to get dev_ctx
   for (auto *in_handle : in_var_handles) {
     in_places.emplace_back(in_handle->place_);
@@ -66,22 +67,23 @@ void ReduceOpHandle::RunImpl() {
       var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(out_var);
 
-  // TODO(zcd): The Place of var_handle is determined at building SSA graph
-  // stage, while the Place of var is determined at runtime. If they are
-  // different, DataTransform should be applied. Currently, it has not been done
-  // yet.
-  PADDLE_ENFORCE_EQ(
-      VariableVisitor::GetMutableTensor(pre_in_var).place().which(),
-      out_var_handle->place_.which(),
-      "Currently, Places of input and output must be all on CPU or all on "
-      "GPU.");
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place();
+  platform::Place t_out_p;
+  if (platform::is_gpu_place(in_p)) {
+    PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_),
+                   "Places of input and output must be all on GPU.");
+    t_out_p = out_var_handle->place_;
+  } else {
+    t_out_p = platform::CPUPlace();
+  }
 
   if (pre_in_var->IsType<framework::SelectedRows>()) {
     std::vector<const SelectedRows *> in_selected_rows =
         GetInputValues<SelectedRows>(in_var_handles, var_scopes);
 
-    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_,
-                       out_var_handle->place_,
+    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
                        out_var->GetMutable<framework::SelectedRows>());
   } else {
     std::vector<const LoDTensor *> lod_tensors =
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index dafd4e8d6b..64e5d93081 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -48,10 +48,6 @@ class SSAGraphBuilder {
                                                const platform::Place &place,
                                                size_t place_offset);
 
-  static VarHandle *GetLatestVarHandle(SSAGraph *graph,
-                                       const std::string &each_var_name,
-                                       size_t place_offset);
-
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
   static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index 7f30a6573b..cae9af7217 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -62,19 +62,10 @@ struct VarHandle : public VarHandleBase {
   std::string name_;
   platform::Place place_;
 
-  // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
-  // member variables(version_, scope_id_, name_, place_) must be equal. But
-  // sometimes judging whether the two var_handle is equal is actually to
-  // determine whether the two Variables that represented by var_handle is the
-  // same. And the same Variable may have many different var_handles, the
-  // version_ of these var_handles is different. So I don't take care of
-  // version_ temporarily when overloading equal.
-  bool operator==(const VarHandle& o) const {
+  bool IsTheSameVar(const VarHandle& o) const {
     return o.generated_op_ == generated_op_ && o.name_ == name_ &&
            o.scope_idx_ == scope_idx_;
   }
-
-  bool operator!=(const VarHandle& o) const { return !this->operator==(o); }
 };
 
 // Dummy Variable. It is used to represent dependencies between operators
diff --git a/paddle/fluid/framework/details/variable_visitor.cc b/paddle/fluid/framework/details/variable_visitor.cc
index 99487a304f..3dfd14419d 100644
--- a/paddle/fluid/framework/details/variable_visitor.cc
+++ b/paddle/fluid/framework/details/variable_visitor.cc
@@ -88,7 +88,7 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
   VisitVariable(src, &visitor);
 }
 
-struct EnforceEqualShapeAndDTypeVisitor {
+struct EnforceShapeAndDTypeEQVisitor {
   const Variable* trg_;
 
   void operator()(const LoDTensor& src) {
@@ -130,7 +130,7 @@ struct EnforceEqualShapeAndDTypeVisitor {
 
 void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1,
                                              const Variable& var2) {
-  EnforceEqualShapeAndDTypeVisitor visitor{&var1};
+  EnforceShapeAndDTypeEQVisitor visitor{&var1};
   VisitVariable(var2, &visitor);
 }