|
|
|
@ -17,6 +17,7 @@
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor_array.h"
|
|
|
|
|
#include "paddle/fluid/framework/scope.h"
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
@ -30,14 +31,13 @@ namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
EagerDeletionOpHandle::EagerDeletionOpHandle(
|
|
|
|
|
ir::Node *node, const Scope *scope, const platform::Place &place,
|
|
|
|
|
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
|
|
|
|
|
ir::AtomicReferenceCountMap *ref_cnts)
|
|
|
|
|
ir::Node *node, Scope *scope, const platform::Place &place,
|
|
|
|
|
const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
|
|
|
|
|
: OpHandleBase(node),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
var_names_(var_names.begin(), var_names.end()),
|
|
|
|
|
gc_(gc),
|
|
|
|
|
ref_cnts_(ref_cnts) {
|
|
|
|
|
place_(place),
|
|
|
|
|
var_infos_(vars.begin(), vars.end()),
|
|
|
|
|
gc_(gc) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
|
dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
|
|
|
|
@ -50,7 +50,10 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty(), "Var names cannot be empty");
|
|
|
|
|
for (auto *var : var_infos_) {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
|
|
|
|
@ -63,30 +66,43 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void EagerDeletionOpHandle::InitCUDA() {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
int dev_id =
|
|
|
|
|
boost::get<platform::CUDAPlace>(dev_ctxes_.begin()->first).device;
|
|
|
|
|
events_[dev_id] = nullptr;
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void EagerDeletionOpHandle::CallOnce() {
|
|
|
|
|
PADDLE_ENFORCE(vars_.empty(), "vars_ must be initialized here");
|
|
|
|
|
Scope *exec_scope = local_exec_scopes_[0];
|
|
|
|
|
for (auto *var_info : var_infos_) {
|
|
|
|
|
auto *var = exec_scope->FindVar(var_info->Name());
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Variable %s should not be nullptr",
|
|
|
|
|
var_info->Name());
|
|
|
|
|
vars_.emplace_back(var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
|
|
|
|
|
|
|
|
|
|
void EagerDeletionOpHandle::RunImpl() {
|
|
|
|
|
if (vars_.size() != var_infos_.size()) {
|
|
|
|
|
CallOnce();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(Name());
|
|
|
|
|
Scope *exec_scope = nullptr;
|
|
|
|
|
std::deque<std::shared_ptr<memory::Allocation>> garbages;
|
|
|
|
|
for (auto &name : var_names_) {
|
|
|
|
|
auto it = ref_cnts_->find(name);
|
|
|
|
|
// Reference count has not decreased to 0
|
|
|
|
|
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
|
|
|
|
|
for (size_t i = 0; i < var_infos_.size(); ++i) {
|
|
|
|
|
auto *var_info = var_infos_[i];
|
|
|
|
|
if (var_info->IsSkipped() || !var_info->DecreaseRefCnt()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!exec_scope) {
|
|
|
|
|
exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Var not found
|
|
|
|
|
auto *var = exec_scope->FindVar(name);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
VLOG(2) << "Erase variable " << var_info->Name() << " on " << place_;
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "Erase variable " << name;
|
|
|
|
|
Variable *var = vars_[i];
|
|
|
|
|
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
|
|
|
|
@ -100,7 +116,7 @@ void EagerDeletionOpHandle::RunImpl() {
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Type %s of %s is not supported eager deletion",
|
|
|
|
|
framework::ToTypeName(var->Type()), name);
|
|
|
|
|
framework::ToTypeName(var->Type()), var_info->Name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|