From e7684911fd7680a2c5576da0833b7558a4ff9ba0 Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Wed, 11 Apr 2018 16:32:31 +0800
Subject: [PATCH] add gather op handle

---
 paddle/fluid/framework/details/CMakeLists.txt |  14 +-
 .../framework/details/broadcast_op_handle.cc  |  39 +--
 .../framework/details/broadcast_op_handle.h   |   1 -
 .../details/broadcast_op_handle_test.cc       |   6 +-
 .../framework/details/gather_op_handle.cc     | 121 ++++++++++
 .../framework/details/gather_op_handle.h      |  52 ++++
 .../details/gather_op_handle_test.cc          | 227 ++++++++++++++++++
 7 files changed, 432 insertions(+), 28 deletions(-)
 create mode 100644 paddle/fluid/framework/details/gather_op_handle.cc
 create mode 100644 paddle/fluid/framework/details/gather_op_handle.h
 create mode 100644 paddle/fluid/framework/details/gather_op_handle_test.cc

diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 2a87f02bd5..3644ed9cb7 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -5,22 +5,22 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 if(WITH_GPU)
     nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
         dynload_cuda)
-    nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
-endif()
-cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
-cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
-
-if(WITH_GPU)
     set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
 else()
     set(multi_devices_graph_builder_deps)
 endif()
+cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
+cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
             scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
         device_context broadcast_op_handle)
+cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
+        device_context gather_op_handle)
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 2c99a347bf..7cd13a50f5 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -18,7 +18,7 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-Tensor *GetTensorFromVar(Variable *in_var) {
+static Tensor *GetTensorFromVar(Variable *in_var) {
   if (in_var->IsType<LoDTensor>()) {
     return in_var->GetMutable<LoDTensor>();
   } else if (in_var->IsType<SelectedRows>()) {
@@ -52,29 +52,34 @@ void BroadcastOpHandle::RunImpl() {
     auto &out_p = out_handle->place_;
 
     auto out_scope_idx = out_handle->scope_idx_;
-    PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(), "");
+    PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(),
+                      "%s is not the the local_scopes ", out_handle->name_);
     auto *s = local_scopes_[out_scope_idx];
     auto out_var = s->FindVar(out_handle->name_);
 
-    PADDLE_ENFORCE_EQ(out_var->Type(), in_var->Type(), "");
+    PADDLE_ENFORCE_EQ(
+        out_var->Type(), in_var->Type(),
+        "The type of input and output is not equal. (%s_%d vs %s_%d)",
+        out_handle->name_, out_handle->scope_idx_, in_var_handle->name_,
+        in_var_handle->scope_idx_);
 
     if (in_var->IsType<framework::SelectedRows>()) {
-      auto in_sr = in_var->GetMutable<framework::SelectedRows>();
-      auto out = out_var->GetMutable<framework::SelectedRows>();
-      if (in_sr == out) continue;
-      out->set_height(in_sr->height());
-      out->set_rows(in_sr->rows());
-      out->mutable_value()->Resize(in_sr->value().dims());
-      out->mutable_value()->mutable_data(out_p, in_sr->value().type());
+      auto &in_sr = in_var->Get<framework::SelectedRows>();
+      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
+      if (&in_sr == out_sr) continue;
+      out_sr->set_height(in_sr.height());
+      out_sr->set_rows(in_sr.rows());
+      out_sr->mutable_value()->Resize(in_sr.value().dims());
+      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
     } else if (in_var->IsType<framework::LoDTensor>()) {
-      auto in_lod = in_var->GetMutable<framework::LoDTensor>();
-      auto out = out_var->GetMutable<framework::LoDTensor>();
-      if (in_lod == out) continue;
-      out->set_lod(in_lod->lod());
-      out->Resize(in_lod->dims());
-      out->mutable_data(out_p, in_lod->type());
+      auto in_lod = in_var->Get<framework::LoDTensor>();
+      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
+      if (&in_lod == out_lod) continue;
+      out_lod->set_lod(in_lod.lod());
+      out_lod->Resize(in_lod.dims());
+      out_lod->mutable_data(out_p, in_lod.type());
     } else {
-      PADDLE_THROW("Var should be LoDTensor or SelectedRows");
+      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
     }
 
     Tensor *out_tensor = GetTensorFromVar(out_var);
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h
index 06ec164ce0..74c0a6a098 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.h
+++ b/paddle/fluid/framework/details/broadcast_op_handle.h
@@ -35,7 +35,6 @@ namespace details {
 struct BroadcastOpHandle : public OpHandleBase {
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
-  //  const platform::ContextMap &ctxs_;
 
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
index d03115f0be..29cf120c76 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
@@ -84,7 +84,7 @@ class BroadcastTester : public ::testing::Test {
       bc_op_handle_->AddOutput(out_var_handle);
     }
   }
-  void BroadcastDestroy() {
+  void BroadcastOpDestroy() {
     for (auto in : bc_op_handle_->inputs_) {
       delete in;
     }
@@ -139,7 +139,7 @@ class BroadcastTester : public ::testing::Test {
       }
     }
 
-    BroadcastDestroy();
+    BroadcastOpDestroy();
   }
 
   void TestBroadcastSelectedRows() {
@@ -188,7 +188,7 @@ class BroadcastTester : public ::testing::Test {
       }
     }
 
-    BroadcastDestroy();
+    BroadcastOpDestroy();
   }
 
  public:
diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc
new file mode 100644
index 0000000000..9407868372
--- /dev/null
+++ b/paddle/fluid/framework/details/gather_op_handle.cc
@@ -0,0 +1,121 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/gather_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+static Tensor *GetTensorFromVar(Variable *in_var) {
+  if (in_var->IsType<LoDTensor>()) {
+    return in_var->GetMutable<LoDTensor>();
+  } else if (in_var->IsType<SelectedRows>()) {
+    return in_var->GetMutable<SelectedRows>()->mutable_value();
+  } else {
+    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
+  }
+  return nullptr;
+}
+GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
+                               const std::vector<platform::Place> &places)
+    : local_scopes_(local_scopes), places_(places) {}
+
+void GatherOpHandle::RunImpl() {
+  PADDLE_ENFORCE_EQ(this->inputs_.size(), places_.size());
+  PADDLE_ENFORCE_EQ(this->outputs_.size(), 1);
+
+  // Wait input done, this Wait is asynchronous operation
+  for (auto *in : inputs_) {
+    if (inputs_[0]->generated_op_) {
+      auto &p = static_cast<VarHandle *>(in)->place_;
+      in->generated_op_->Wait(dev_ctxes_[p]);
+    }
+  }
+  auto in_0_handle = static_cast<VarHandle *>(inputs_[0]);
+  auto pre_in_var =
+      local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
+
+  std::vector<int64_t> out_rows;
+  std::vector<Tensor *> in_tensors;
+  std::vector<platform::Place> in_places;
+
+  // gather the inputs
+  for (auto *in : inputs_) {
+    auto in_handle = static_cast<VarHandle *>(in);
+    auto in_p = in_handle->place_;
+    in_places.push_back(in_p);
+    PADDLE_ENFORCE_LT(in_handle->scope_idx_, local_scopes_.size(),
+                      "%s is not the the local_scopes ", in_handle->name_);
+
+    auto *s = local_scopes_[in_handle->scope_idx_];
+    auto in_var = s->FindVar(in_handle->name_);
+    PADDLE_ENFORCE_EQ(in_var->Type(), pre_in_var->Type(),
+                      "The type of input is not consistent.");
+
+    if (in_var->IsType<framework::SelectedRows>()) {
+      auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
+      auto &in_sr = in_var->Get<framework::SelectedRows>();
+      auto in_sr_rows = in_sr.rows();
+      out_rows.insert(out_rows.begin(), in_sr_rows.begin(), in_sr_rows.end());
+      PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(), "");
+      PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), "");
+    } else if (in_var->IsType<framework::LoDTensor>()) {
+      auto &pre_in = pre_in_var->Get<framework::LoDTensor>();
+      auto &in_lodtensor = in_var->Get<framework::LoDTensor>();
+      PADDLE_ENFORCE_EQ(in_lodtensor.lod(), pre_in.lod());
+      PADDLE_ENFORCE_EQ(in_lodtensor.dims(), pre_in.dims());
+    } else {
+      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
+    }
+    in_tensors.push_back(GetTensorFromVar(in_var));
+    pre_in_var = in_var;
+  }
+
+  // write the output
+  auto out_handle = static_cast<VarHandle *>(this->outputs_[0]);
+  auto &out_place = out_handle->place_;
+  auto out_scope_idx = out_handle->scope_idx_;
+  auto out_var = local_scopes_[out_scope_idx]->FindVar(out_handle->name_);
+
+  if (pre_in_var->IsType<framework::SelectedRows>()) {
+    auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
+    auto out = out_var->GetMutable<framework::SelectedRows>();
+    out->set_height(pre_in.height());
+    out->set_rows(out_rows);
+    size_t rows = out_rows.size();
+    DDim out_dim = pre_in.GetCompleteDims();
+    out_dim[0] = static_cast<int64_t>(rows);
+    out->mutable_value()->Resize(out_dim);
+    out->mutable_value()->mutable_data(out_place, pre_in.value().type());
+    auto out_tensor = out->mutable_value();
+    // copy
+    int s = 0, e = 0;
+    for (size_t j = 0; j < in_tensors.size(); ++j) {
+      e += in_tensors[j]->dims()[0];
+      auto sub_out = out_tensor->Slice(s, e);
+      paddle::framework::TensorCopy(*(in_tensors[j]), out_place,
+                                    *(dev_ctxes_[in_places[j]]), &sub_out);
+      s = e;
+    }
+  } else if (pre_in_var->IsType<framework::LoDTensor>()) {
+  } else {
+    PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
+  }
+}
+
+std::string GatherOpHandle::Name() const { return "broadcast"; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h
new file mode 100644
index 0000000000..48e1db227b
--- /dev/null
+++ b/paddle/fluid/framework/details/gather_op_handle.h
@@ -0,0 +1,52 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+/*
+ * Broadcast the input to all scope.
+ *
+ */
+struct GatherOpHandle : public OpHandleBase {
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
+
+  GatherOpHandle(const std::vector<Scope *> &local_scopes,
+                 const std::vector<platform::Place> &places);
+
+  std::string Name() const override;
+
+  bool IsMultiDeviceTransfer() override { return false; };
+
+ protected:
+  void RunImpl() override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc
new file mode 100644
index 0000000000..a029a2d266
--- /dev/null
+++ b/paddle/fluid/framework/details/gather_op_handle_test.cc
@@ -0,0 +1,227 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/gather_op_handle.h"
+#include "gtest/gtest.h"
+
+#include "paddle/fluid/platform/device_context.h"
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+
+// test data amount
+const f::DDim kDims = {20, 20};
+
+class GatherTester : public ::testing::Test {
+ public:
+  void InitCtx(bool use_gpu) {
+    if (use_gpu) {
+#ifdef PADDLE_WITH_CUDA
+      int count = p::GetCUDADeviceCount();
+      if (count <= 1) {
+        LOG(WARNING) << "Cannot test multi-gpu Gather, because the CUDA "
+                        "device count is "
+                     << count;
+        exit(0);
+      }
+      for (int i = 0; i < count; ++i) {
+        auto p = p::CUDAPlace(i);
+        gpu_list_.push_back(p);
+        ctxs_.emplace_back(new p::CUDADeviceContext(p));
+      }
+#else
+      PADDLE_THROW("CUDA is not support.");
+#endif
+    } else {
+      int count = 8;
+      for (int i = 0; i < count; ++i) {
+        auto p = p::CPUPlace();
+        gpu_list_.push_back(p);
+        ctxs_.emplace_back(new p::CPUDeviceContext(p));
+      }
+    }
+  }
+
+  template <class T>
+  void InitGatherOp(int input_scope_idx) {
+    for (size_t j = 0; j < gpu_list_.size(); ++j) {
+      local_scope_.push_back(&g_scope_.NewScope());
+      auto* out_var = local_scope_[j]->Var("input");
+      out_var->GetMutable<T>();
+    }
+    auto* in_var = local_scope_[input_scope_idx]->Var("out");
+    in_var->GetMutable<T>();
+
+    gather_op_handle_ = new f::details::GatherOpHandle(local_scope_, gpu_list_);
+
+    f::details::VarHandle* out_var_handle = new f::details::VarHandle();
+    out_var_handle->place_ = gpu_list_[input_scope_idx];
+    out_var_handle->name_ = "out";
+    out_var_handle->version_ = 2;
+    out_var_handle->scope_idx_ = input_scope_idx;
+    out_var_handle->generated_op_ = gather_op_handle_;
+    gather_op_handle_->AddOutput(out_var_handle);
+
+    for (size_t j = 0; j < gpu_list_.size(); ++j) {
+      gather_op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j];
+      f::details::VarHandle* in_var_handle = new f::details::VarHandle();
+      in_var_handle->place_ = gpu_list_[j];
+      in_var_handle->name_ = "input";
+      in_var_handle->version_ = 1;
+      in_var_handle->scope_idx_ = j;
+      in_var_handle->generated_op_ = nullptr;
+      gather_op_handle_->AddInput(in_var_handle);
+    }
+  }
+  void GatherOpDestroy() {
+    for (auto in : gather_op_handle_->inputs_) {
+      delete in;
+    }
+    for (auto out : gather_op_handle_->outputs_) {
+      delete out;
+    }
+    delete gather_op_handle_;
+    for (size_t j = 0; j < ctxs_.size(); ++j) {
+      delete ctxs_[j];
+    }
+  }
+
+  void WaitAll() {
+    for (size_t j = 0; j < ctxs_.size(); ++j) {
+      ctxs_[j]->Wait();
+    }
+  }
+
+  void TestGatherLodTensor() {
+    //    int input_scope_idx = 0;
+    //    InitGatherOp<f::LoDTensor>(input_scope_idx);
+    //
+    //    auto in_var = local_scope_[input_scope_idx]->Var("input");
+    //    auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
+    //    in_lod_tensor->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
+    //
+    //    std::vector<float> send_vector(f::product(kDims), input_scope_idx +
+    //    12);
+    //    for (size_t k = 0; k < send_vector.size(); ++k) {
+    //      send_vector[k] = k;
+    //    }
+    //    f::LoD lod{{0, 10, 20}};
+    //    paddle::framework::TensorFromVector<float>(
+    //        send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor);
+    //    in_lod_tensor->set_lod(lod);
+    //
+    //    gather_op_handle_->Run(false);
+    //
+    //    WaitAll();
+    //
+    //    p::CPUPlace cpu_place;
+    //    for (size_t j = 0; j < gpu_list_.size(); ++j) {
+    //      auto out_var = local_scope_[j]->Var("out");
+    //      auto out_tensor = out_var->Get<f::LoDTensor>();
+    //      PADDLE_ENFORCE_EQ(out_tensor.lod(), lod, "lod is not equal.");
+    //
+    //      f::Tensor result_tensor;
+    //      f::TensorCopy(out_tensor, cpu_place, *(ctxs_[j]), &result_tensor);
+    //      float* ct = result_tensor.mutable_data<float>(cpu_place);
+    //
+    //      for (int64_t j = 0; j < f::product(kDims); ++j) {
+    //        ASSERT_NEAR(ct[j], send_vector[j], 1e-5);
+    //      }
+    //    }
+    //
+    //    GatherOpDestroy();
+  }
+
+  void TestGatherSelectedRows() {
+    int output_scope_idx = 0;
+    InitGatherOp<f::SelectedRows>(output_scope_idx);
+
+    int height = kDims[0] * 2;
+    std::vector<int64_t> rows{0, 1, 2, 3, 3, 0, 14, 7, 3, 1,
+                              2, 4, 6, 3, 1, 1, 1,  1, 3, 7};
+    std::vector<float> send_vector(f::product(kDims));
+    for (size_t k = 0; k < send_vector.size(); ++k) {
+      send_vector[k] = k;
+    }
+
+    for (size_t input_scope_idx = 0; input_scope_idx < gpu_list_.size();
+         ++input_scope_idx) {
+      auto in_var = local_scope_[input_scope_idx]->Var("input");
+      auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
+      auto value = in_selected_rows->mutable_value();
+      value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
+
+      in_selected_rows->set_height(height);
+      in_selected_rows->set_rows(rows);
+
+      paddle::framework::TensorFromVector<float>(
+          send_vector, *(ctxs_[input_scope_idx]), value);
+      value->Resize(kDims);
+    }
+
+    gather_op_handle_->Run(false);
+
+    WaitAll();
+
+    p::CPUPlace cpu_place;
+
+    auto out_var = local_scope_[output_scope_idx]->Var("out");
+    auto& out_select_rows = out_var->Get<f::SelectedRows>();
+    auto rt = out_select_rows.value();
+
+    PADDLE_ENFORCE_EQ(out_select_rows.height(), height, "height is not equal.");
+    for (size_t k = 0; k < out_select_rows.rows().size(); ++k) {
+      PADDLE_ENFORCE_EQ(out_select_rows.rows()[k], rows[k % rows.size()]);
+    }
+
+    f::Tensor result_tensor;
+    f::TensorCopy(rt, cpu_place, *(ctxs_[output_scope_idx]), &result_tensor);
+    float* ct = result_tensor.data<float>();
+
+    for (int64_t j = 0; j < f::product(kDims); ++j) {
+      ASSERT_NEAR(ct[j], send_vector[j % send_vector.size()], 1e-5);
+    }
+
+    GatherOpDestroy();
+  }
+
+ public:
+  f::Scope g_scope_;
+  std::vector<p::DeviceContext*> ctxs_;
+  std::vector<f::Scope*> local_scope_;
+  std::vector<p::Place> gpu_list_;
+  f::details::GatherOpHandle* gather_op_handle_;
+};
+
+// TEST_F(GatherTester, TestCPUGatherTestLodTensor) {
+//  InitCtx(false);
+//  TestGatherLodTensor();
+//}
+
+TEST_F(GatherTester, TestCPUGatherTestSelectedRows) {
+  InitCtx(false);
+  TestGatherSelectedRows();
+}
+
+#ifdef PADDLE_WITH_CUDA
+// TEST_F(GatherTester, TestGPUGatherTestLodTensor) {
+//  InitCtx(true);
+//  TestGatherLodTensor();
+//}
+
+TEST_F(GatherTester, TestGPUGatherTestSelectedRows) {
+  InitCtx(true);
+  TestGatherSelectedRows();
+}
+#endif