|
|
|
@ -25,6 +25,45 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
static void CollectUniqueAllocations(
|
|
|
|
|
const Variable &var,
|
|
|
|
|
std::unordered_set<memory::Allocation *> *allocation_set) {
|
|
|
|
|
if (var.IsType<LoDTensor>()) {
|
|
|
|
|
allocation_set->insert(var.Get<LoDTensor>().Holder().get());
|
|
|
|
|
} else if (var.IsType<SelectedRows>()) {
|
|
|
|
|
allocation_set->insert(var.Get<SelectedRows>().value().Holder().get());
|
|
|
|
|
} else if (var.IsType<LoDTensorArray>()) {
|
|
|
|
|
for (auto &t : var.Get<LoDTensorArray>()) {
|
|
|
|
|
allocation_set->insert(t.Holder().get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CollectUniqueAllocations(
|
|
|
|
|
const Scope &scope,
|
|
|
|
|
std::unordered_set<memory::Allocation *> *allocation_set) {
|
|
|
|
|
for (auto &var_name : scope.LocalVarNames()) {
|
|
|
|
|
CollectUniqueAllocations(*scope.FindVar(var_name), allocation_set);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto *kid : scope.kids()) {
|
|
|
|
|
CollectUniqueAllocations(*kid, allocation_set);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static size_t GetScopeVarMemorySize(const Scope &scope) {
|
|
|
|
|
std::unordered_set<memory::Allocation *> allocation_set;
|
|
|
|
|
CollectUniqueAllocations(scope, &allocation_set);
|
|
|
|
|
size_t memory_size = 0;
|
|
|
|
|
for (auto *allocation : allocation_set) {
|
|
|
|
|
if (allocation) {
|
|
|
|
|
memory_size += allocation->size();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return memory_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
|
|
|
|
|
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
|
|
|
|
|
std::vector<Scope *> local_exec_scopes, std::vector<VariableInfo> var_infos,
|
|
|
|
@ -55,10 +94,27 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
|
|
|
|
|
eptr = std::current_exception();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (VLOG_IS_ON(5)) {
|
|
|
|
|
for (auto *scope : local_exec_scopes_) {
|
|
|
|
|
VLOG(5) << "Left "
|
|
|
|
|
<< string::HumanReadableSize(GetScopeVarMemorySize(*scope))
|
|
|
|
|
<< " on scope " << scope << " before deleting";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
++drop_scope_counter_;
|
|
|
|
|
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
|
|
|
|
|
DropLocalExeScopes();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (VLOG_IS_ON(5)) {
|
|
|
|
|
for (auto *scope : local_exec_scopes_) {
|
|
|
|
|
VLOG(5) << "Left "
|
|
|
|
|
<< string::HumanReadableSize(GetScopeVarMemorySize(*scope))
|
|
|
|
|
<< " on scope " << scope << " after deleting";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (eptr) {
|
|
|
|
|
std::rethrow_exception(eptr);
|
|
|
|
|
} else {
|
|
|
|
|