From 594dc4d8f0a74fe7640d22830cd221b91cbebbb5 Mon Sep 17 00:00:00 2001
From: sneaxiy <sneaxiy@126.com>
Date: Thu, 10 Jan 2019 01:47:22 +0000
Subject: [PATCH] partial gc 1st version test=develop

---
 paddle/fluid/framework/details/CMakeLists.txt |   2 +-
 .../details/eager_deletion_op_handle.cc       |  14 +-
 .../framework/details/eager_deletion_pass.cc  | 125 +++++++++++++++++-
 .../framework/details/reference_count_pass.cc |   9 --
 .../details/reference_count_pass_helper.cc    |  15 ++-
 .../details/reference_count_pass_helper.h     |   8 +-
 python/paddle/fluid/__init__.py               |   1 +
 7 files changed, 158 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 179aa14528..cb34712975 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -54,7 +54,7 @@ cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
 cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
         all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
-cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
+cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle proto_desc var_handle)
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
 cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
 cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc
index 03fbfd7f24..58cdd65601 100644
--- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc
+++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc
@@ -16,6 +16,7 @@
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/platform/profiler.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #endif
@@ -45,6 +46,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
     }
   }
 #endif
+  PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
 }
 
 EagerDeletionOpHandle::~EagerDeletionOpHandle() {
@@ -60,7 +62,13 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
 std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
 
 void EagerDeletionOpHandle::RunImpl() {
-  auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+#ifdef PADDLE_WITH_CUDA
+  platform::RecordEvent record_event(Name(), dev_ctx_);
+#else
+  platform::RecordEvent record_event(Name(), nullptr);
+#endif
+
+  Scope *exec_scope = nullptr;
   std::deque<std::shared_ptr<memory::Allocation>> garbages;
   for (auto &name : var_names_) {
     auto it = ref_cnts_->find(name);
@@ -69,6 +77,10 @@ void EagerDeletionOpHandle::RunImpl() {
       continue;
     }
 
+    if (!exec_scope) {
+      exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    }
+
     auto *var = exec_scope->FindVar(name);
     if (var == nullptr) {
       continue;
diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc
index 4e42d0b497..6c8cb66b10 100644
--- a/paddle/fluid/framework/details/eager_deletion_pass.cc
+++ b/paddle/fluid/framework/details/eager_deletion_pass.cc
@@ -12,8 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <algorithm>
+#include <cstdlib>
+#include <functional>
+#include <numeric>
 #include <queue>
 #include <string>
+#include <tuple>
 #include <vector>
 
 #include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -22,10 +25,120 @@
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 
+DEFINE_double(fraction_of_eager_deletion, 1.0, "Fraction of eager deletion");
+DEFINE_bool(eager_delete_tensor_only, false, "Only eagerly delete LoDTensor vars");
+
 namespace paddle {
 namespace framework {
 namespace details {
 
+namespace {  // NOLINT
+using OpToVarNameSetMap =
+    std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
+}  // NOLINT
+
+static bool IsLoDTensor(VarDesc *var) {
+  return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
+}
+
+static int64_t GetNumel(const GraphVars &vars, const std::string &var_name,
+                        size_t scope_idx) {
+  auto *var_desc = TryGetLatestVarDesc(vars[scope_idx].at(var_name));
+  PADDLE_ENFORCE(IsLoDTensor(var_desc));
+  auto dims = var_desc->GetShape();
+  return std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
+                         std::multiplies<int64_t>());
+}
+
+static void SplitIntoLoDTensorAndNonLoDTensorVars(
+    const OpToVarNameSetMap &m, const GraphVars &vars,
+    OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
+  lod_tensors->clear();
+  other_vars->clear();
+
+  for (auto &op_vars_pair : m) {
+    for (auto &var_name : op_vars_pair.second) {
+      auto *var_desc = TryGetLatestVarDesc(
+          vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
+      if (IsLoDTensor(var_desc)) {
+        (*lod_tensors)[op_vars_pair.first].insert(var_name);
+      } else {
+        (*other_vars)[op_vars_pair.first].insert(var_name);
+      }
+    }
+  }
+}
+
+static OpToVarNameSetMap ShrinkGCVars(const OpToVarNameSetMap &m,
+                                      const GraphVars &vars,
+                                      double fraction_of_memory_size,
+                                      bool delete_lod_tensor_only = false) {
+  // Do not perform gc
+  if (fraction_of_memory_size <= 0.0) return {};
+
+  // Perform complete gc
+  if (fraction_of_memory_size >= 1.0) {
+    if (delete_lod_tensor_only) {
+      OpToVarNameSetMap lod_tensors, other_vars;
+      SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
+      return lod_tensors;
+    } else {
+      return m;
+    }
+  }
+
+  // Perform partial gc
+  OpToVarNameSetMap lod_tensors, other_vars;
+  SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
+
+  using TupleType = std::tuple<std::string, ComputationOpHandle *, int64_t>;
+
+  std::unordered_map<size_t, std::vector<TupleType>> place_to_vars;
+  std::unordered_map<size_t, int64_t> total_memory_size;
+  for (auto &op_vars_pair : lod_tensors) {
+    auto scope_idx = op_vars_pair.first->GetScopeIdx();
+    int64_t size = 0;
+    for (auto &var_name : op_vars_pair.second) {
+      auto var_size = GetNumel(vars, var_name, scope_idx);
+      size += std::abs(var_size);
+      place_to_vars[scope_idx].emplace_back(var_name, op_vars_pair.first,
+                                            var_size);
+    }
+    total_memory_size[scope_idx] += size;
+  }
+
+  for (auto &pair : place_to_vars) {
+    std::sort(pair.second.begin(), pair.second.end(),
+              [](const TupleType &t1, const TupleType &t2) {
+                return std::abs(std::get<2>(t1)) > std::abs(std::get<2>(t2));
+              });
+  }
+
+  OpToVarNameSetMap ret;
+  for (auto &pair : place_to_vars) {
+    auto desired_delete_size = static_cast<int64_t>(
+        fraction_of_memory_size * total_memory_size.at(pair.first));
+    int64_t cur_size = 0;
+    for (size_t i = 0; i < pair.second.size() && cur_size < desired_delete_size;
+         ++i) {
+      auto &var_name = std::get<0>(pair.second[i]);
+      auto *op = std::get<1>(pair.second[i]);
+      cur_size += std::abs(std::get<2>(pair.second[i]));
+      ret[op].insert(var_name);
+    }
+  }
+
+  if (!delete_lod_tensor_only) {
+    for (auto &op_vars_pair : other_vars) {
+      for (auto &var_name : op_vars_pair.second) {
+        ret[op_vars_pair.first].insert(var_name);
+      }
+    }
+  }
+
+  return ret;
+}
+
 std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts =
@@ -43,9 +156,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
 
   // a reverse map of last_live_ops
   //   i.e., last op --> variable names which can be deleted.
-  std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
-      op_vars_map;
-
+  OpToVarNameSetMap op_vars_map;
   for (auto &var_ops_map : last_live_ops) {
     for (auto &var_ops_pair : var_ops_map) {
       const std::string &var_name = var_ops_pair.first;
@@ -55,6 +166,10 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     }
   }
 
+  op_vars_map =
+      ShrinkGCVars(op_vars_map, vars, FLAGS_fraction_of_eager_deletion,
+                   FLAGS_eager_delete_tensor_only);
+
   for (auto &pair : op_vars_map) {
     auto *op = pair.first;
     auto &var_names = pair.second;
@@ -85,6 +200,10 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     eager_deletion_op->AddOutput(dummy_leaf);
   }
 
+  VLOG(10) << "FLAGS_fraction_of_eager_deletion = "
+           << FLAGS_fraction_of_eager_deletion;
+  VLOG(10) << "FLAGS_eager_delete_tensor_only = "
+           << FLAGS_eager_delete_tensor_only;
   VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
   return graph;
 }
diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc
index 13a042d8e6..892f638f1f 100644
--- a/paddle/fluid/framework/details/reference_count_pass.cc
+++ b/paddle/fluid/framework/details/reference_count_pass.cc
@@ -189,15 +189,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
   return shrink_func(computation_op);
 }
 
-static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
-  VarDesc *var_desc = nullptr;
-  std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
-    var_desc = var_handle->Node()->Var();
-    return var_desc != nullptr;
-  });
-  return var_desc;
-}
-
 std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.cc b/paddle/fluid/framework/details/reference_count_pass_helper.cc
index 89bd08c2d0..94de0e6ab0 100644
--- a/paddle/fluid/framework/details/reference_count_pass_helper.cc
+++ b/paddle/fluid/framework/details/reference_count_pass_helper.cc
@@ -13,9 +13,23 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
+#include <algorithm>
+#include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/framework/var_desc.h"
 
 namespace paddle {
 namespace framework {
-namespace details {}  // namespace details
+namespace details {
+
+VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
+  VarDesc *var_desc = nullptr;
+  std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
+    var_desc = var_handle->Node()->Var();
+    return var_desc != nullptr;
+  });
+  return var_desc;
+}
+
+}  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.h b/paddle/fluid/framework/details/reference_count_pass_helper.h
index 1c083dbf00..d9e8776d7e 100644
--- a/paddle/fluid/framework/details/reference_count_pass_helper.h
+++ b/paddle/fluid/framework/details/reference_count_pass_helper.h
@@ -25,6 +25,10 @@
 
 namespace paddle {
 namespace framework {
+
+class VarDesc;
+class VarHandle;
+
 namespace details {
 
 class ComputationOpHandle;
@@ -43,9 +47,11 @@ const char kGarbageCollector[] = "garbage_collector";
 const char kAllPlaces[] = "all_places";
 
 using LastLiveOpsOfVars =
-    std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
+    std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
 const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
 
+VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index f9f3807b15..794f583037 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -127,6 +127,7 @@ def __bootstrap__():
         'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem',
         'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size",
         'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
+        'fraction_of_eager_deletion', 'eager_delete_tensor_only',
         'allocator_strategy', 'reader_queue_speed_test_mode',
         'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
         'enable_parallel_graph'