|
|
|
@ -36,26 +36,10 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
|
|
|
|
|
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
if (drop_scope_counter_ == 0) {
|
|
|
|
|
// Create local scopes.
|
|
|
|
|
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
|
|
|
|
|
auto &scope = *it;
|
|
|
|
|
Scope &local_scope = scope->NewScope();
|
|
|
|
|
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
|
|
|
|
|
&local_scope;
|
|
|
|
|
|
|
|
|
|
for (auto &info : var_infos_) {
|
|
|
|
|
if (scope->FindVar(info.name_) != nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (info.persistable_) { // Persistable
|
|
|
|
|
InitializeVariable(scope->Var(info.name_), info.type_);
|
|
|
|
|
} else {
|
|
|
|
|
InitializeVariable(local_scope.Var(info.name_), info.type_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
platform::RecordEvent e("InitLocalExeScopes");
|
|
|
|
|
PrepareLocalExeScopes();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<framework::LoDTensor> fetch_data;
|
|
|
|
|
std::exception_ptr eptr = nullptr;
|
|
|
|
|
try {
|
|
|
|
@ -64,9 +48,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
|
|
|
|
|
eptr = std::current_exception();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun");
|
|
|
|
|
++drop_scope_counter_;
|
|
|
|
|
|
|
|
|
|
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
|
|
|
|
|
DropLocalExeScopes();
|
|
|
|
|
}
|
|
|
|
@ -78,11 +60,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
|
|
|
|
|
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
|
|
|
|
|
drop_scope_counter_ = 0;
|
|
|
|
|
for (auto p : places_) {
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p)->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &scope : local_scopes_) {
|
|
|
|
|
auto &local_scope =
|
|
|
|
|
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
|
|
|
|
@ -91,6 +73,26 @@ void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
|
|
|
|
|
// Create local scopes.
|
|
|
|
|
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
|
|
|
|
|
auto &scope = *it;
|
|
|
|
|
Scope &local_scope = scope->NewScope();
|
|
|
|
|
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
|
|
|
|
|
|
|
|
|
|
for (auto &info : var_infos_) {
|
|
|
|
|
if (scope->FindVar(info.name_) != nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (info.persistable_) { // Persistable
|
|
|
|
|
InitializeVariable(scope->Var(info.name_), info.type_);
|
|
|
|
|
} else {
|
|
|
|
|
InitializeVariable(local_scope.Var(info.name_), info.type_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
|
|
|
|
|
return drop_scope_counter_ == 0;
|
|
|
|
|
}
|
|
|
|
|