|
|
|
@ -14,10 +14,31 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/variable_helper.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
|
|
|
|
|
Scope *scope) {
|
|
|
|
|
Scope &local_scope = scope->NewScope();
|
|
|
|
|
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
|
|
|
|
|
&local_scope;
|
|
|
|
|
|
|
|
|
|
for (auto &info : var_infos) {
|
|
|
|
|
if (scope->FindVar(info.name_) != nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (info.persistable_) { // Persistable
|
|
|
|
|
InitializeVariable(scope->Var(info.name_), info.type_);
|
|
|
|
|
} else {
|
|
|
|
|
InitializeVariable(local_scope.Var(info.name_), info.type_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
|
|
|
|
|
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
|
|
|
|
@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
|
|
|
|
|
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
|
|
|
|
|
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList AsyncSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
std::vector<std::future<FeedFetchList>> run_futures;
|
|
|
|
|
|
|
|
|
|
std::vector<FeedFetchList> fetch_data;
|
|
|
|
|
FeedFetchList ret;
|
|
|
|
|
|
|
|
|
|
fetch_data.reserve(places_.size());
|
|
|
|
|
ret.reserve(fetch_tensors.size());
|
|
|
|
|
exception_holder_.Clear();
|
|
|
|
|
for (auto &node : graphs_[0]->Nodes()) {
|
|
|
|
|
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
|
|
|
|
|
var_infos_.emplace_back();
|
|
|
|
|
var_infos_.back().name_ = node->Var()->Name();
|
|
|
|
|
var_infos_.back().type_ = node->Var()->GetType();
|
|
|
|
|
var_infos_.back().persistable_ = node->Var()->Persistable();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto *scope : local_scopes_) {
|
|
|
|
|
NewTempScopeAndInitVars(var_infos_, scope);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
|
|
|
|
|
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
|
|
|
|
|
VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
|
|
|
|
|
for (size_t i = 1; i < places_.size(); ++i) {
|
|
|
|
|
auto call = [this, i]() -> void {
|
|
|
|
|
VLOG(3) << "start off python thread " << i;
|
|
|
|
|
try {
|
|
|
|
|
return executors_[i]->Run(fetch_tensors);
|
|
|
|
|
while (true) {
|
|
|
|
|
executors_[i]->Run({});
|
|
|
|
|
}
|
|
|
|
|
} catch (...) {
|
|
|
|
|
exception_holder_.Catch(std::current_exception());
|
|
|
|
|
VLOG(3) << "get exception type = " << exception_holder_.Type();
|
|
|
|
|
}
|
|
|
|
|
return FeedFetchList();
|
|
|
|
|
VLOG(3) << "thread " << i << " exited!";
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (pool_) {
|
|
|
|
|
run_futures.emplace_back(pool_->enqueue(std::move(call)));
|
|
|
|
|
} else {
|
|
|
|
|
fetch_data.emplace_back(std::move(call()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (pool_) {
|
|
|
|
|
for (auto &f : run_futures) {
|
|
|
|
|
if (exception_holder_.IsCaught()) {
|
|
|
|
|
f.wait();
|
|
|
|
|
} else {
|
|
|
|
|
fetch_data.emplace_back(std::move(f.get()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
run_futures_.emplace_back(pool_->enqueue(std::move(call)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncSSAGraphExecutor::HandleException() {
|
|
|
|
|
if (exception_holder_.IsCaught()) {
|
|
|
|
|
for (auto &f : run_futures_) {
|
|
|
|
|
VLOG(3) << "wait future";
|
|
|
|
|
f.wait();
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "caught exception " << exception_holder_.Type()
|
|
|
|
|
<< ", rethrow it";
|
|
|
|
|
run_futures_.clear();
|
|
|
|
|
exception_holder_.ReThrow();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList AsyncSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
// init once
|
|
|
|
|
if (run_futures_.size() == 0 && places_.size() > 1) {
|
|
|
|
|
exception_holder_.Clear();
|
|
|
|
|
StartOffPythonTrainLoop();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (places_.size() == 1) {
|
|
|
|
|
exception_holder_.Clear();
|
|
|
|
|
} else {
|
|
|
|
|
HandleException();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList fetch_data;
|
|
|
|
|
fetch_data.reserve(fetch_tensors.size());
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
fetch_data = executors_[0]->Run(fetch_tensors);
|
|
|
|
|
} catch (...) {
|
|
|
|
|
exception_holder_.Catch(std::current_exception());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
HandleException();
|
|
|
|
|
|
|
|
|
|
FeedFetchList ret;
|
|
|
|
|
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
|
|
|
|
|
std::vector<const LoDTensor *> lodtensor_ptrs;
|
|
|
|
|
lodtensor_ptrs.reserve(local_scopes_.size());
|
|
|
|
|
for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
|
|
|
|
|
lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
|
|
|
|
|
}
|
|
|
|
|
lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
|
|
|
|
|
ret.emplace_back();
|
|
|
|
|
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
|
|
|
|
|
}
|
|
|
|
|