|
|
|
@ -43,7 +43,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
|
|
|
|
|
bootstrap_ops_.emplace_back(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
|
|
|
|
|
PrepareAtomicOpDeps();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -52,26 +52,85 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
|
|
|
|
|
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
|
|
|
|
|
op_deps = atomic_op_deps_.get();
|
|
|
|
|
PrepareAtomicOpDeps();
|
|
|
|
|
size_t num_ops = op_deps->size();
|
|
|
|
|
|
|
|
|
|
paddle::framework::FeedFetchList fetches;
|
|
|
|
|
fetches.resize(fetch_tensors.size());
|
|
|
|
|
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
|
|
|
|
|
std::vector<FetchOpHandle *> fetch_ops;
|
|
|
|
|
std::vector<OpHandleBase *> fetch_ops;
|
|
|
|
|
std::vector<OpHandleBase *> ready_fetch_ops;
|
|
|
|
|
exception_.Clear();
|
|
|
|
|
|
|
|
|
|
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
|
|
|
|
|
&fetch_ops, &ready_fetch_ops);
|
|
|
|
|
|
|
|
|
|
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
|
|
|
|
|
// If the num_threads is 1, we can record the order of operator's
|
|
|
|
|
// execution in the first iteration, and in subsequent iterations,
|
|
|
|
|
// run the recorded operators directly. This strategy could make the
|
|
|
|
|
// execution faster.
|
|
|
|
|
VLOG(3) << "Run the traced ops.";
|
|
|
|
|
RunTracedOps(traced_ops_);
|
|
|
|
|
RunTracedOps(fetch_ops);
|
|
|
|
|
if (exception_.IsCaught()) {
|
|
|
|
|
ExecutionFinal(&fetch_ops);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
traced_ops_.clear();
|
|
|
|
|
remaining_ = 0;
|
|
|
|
|
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
|
|
|
|
|
for (auto op : bootstrap_ops_) {
|
|
|
|
|
RunOpAsync(op_deps.get(), op, complete_q);
|
|
|
|
|
}
|
|
|
|
|
for (auto op : ready_fetch_ops) {
|
|
|
|
|
RunOpAsync(op_deps.get(), op, complete_q);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t num_complete = 0;
|
|
|
|
|
while (num_complete != op_deps->size()) {
|
|
|
|
|
size_t num_comp = complete_q->Pop();
|
|
|
|
|
if (num_comp == -1UL) {
|
|
|
|
|
int remaining = 0;
|
|
|
|
|
while (true) {
|
|
|
|
|
remaining = remaining_;
|
|
|
|
|
if (remaining == 0) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < remaining; ++i) {
|
|
|
|
|
complete_q->Pop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (exception_.IsCaught()) {
|
|
|
|
|
ExecutionFinal(&fetch_ops);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
num_complete += num_comp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Wait FetchOps.
|
|
|
|
|
ClearFetchOp(graph_, &fetch_ops);
|
|
|
|
|
return fetches;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FastThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
|
|
|
|
|
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
|
|
|
|
|
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
|
|
|
|
|
std::vector<OpHandleBase *> *fetch_ops,
|
|
|
|
|
std::vector<OpHandleBase *> *ready_fetch_ops) {
|
|
|
|
|
for (auto &fetch_var_name : fetch_tensors) {
|
|
|
|
|
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
|
|
|
|
|
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
|
|
|
|
|
auto it = var_map.find(fetch_var_name);
|
|
|
|
|
if (it != var_map.end()) {
|
|
|
|
|
fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
|
|
|
|
|
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
|
|
|
|
|
auto &var_name = fetch_tensors[i];
|
|
|
|
|
auto fetched_var_it = fetched_vars.find(var_name);
|
|
|
|
|
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
|
|
|
|
|
auto &var_name = fetch_tensors.at(i);
|
|
|
|
|
auto fetched_var_it = fetched_vars->find(var_name);
|
|
|
|
|
PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(),
|
|
|
|
|
"Cannot find fetched variable(%s).(Perhaps the main_program "
|
|
|
|
|
"is not set to ParallelExecutor)",
|
|
|
|
|
var_name);
|
|
|
|
@ -80,8 +139,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
ir::Node *fetch_node =
|
|
|
|
|
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
|
|
|
|
|
auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
|
|
|
|
|
fetch_ops.emplace_back(op);
|
|
|
|
|
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
|
|
|
|
|
fetch_ops->emplace_back(op);
|
|
|
|
|
|
|
|
|
|
for (auto &p : places_) {
|
|
|
|
|
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
|
|
|
|
@ -94,55 +153,22 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
|
|
|
|
|
int dep = static_cast<int>(op->NotReadyInputSize());
|
|
|
|
|
(*op_deps)[op] = dep;
|
|
|
|
|
if (dep == 0) {
|
|
|
|
|
ready_fetch_ops.emplace_back(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t num_complete = 0;
|
|
|
|
|
remaining_ = 0;
|
|
|
|
|
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
|
|
|
|
|
for (auto op : bootstrap_ops_) {
|
|
|
|
|
RunOpAsync(op_deps.get(), op, complete_q);
|
|
|
|
|
}
|
|
|
|
|
for (auto op : ready_fetch_ops) {
|
|
|
|
|
RunOpAsync(op_deps.get(), op, complete_q);
|
|
|
|
|
}
|
|
|
|
|
while (num_complete != op_deps->size()) {
|
|
|
|
|
size_t num_comp = complete_q->Pop();
|
|
|
|
|
if (num_comp == -1UL) {
|
|
|
|
|
int remaining = 0;
|
|
|
|
|
while (true) {
|
|
|
|
|
remaining = remaining_;
|
|
|
|
|
if (remaining == 0) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < remaining; ++i) {
|
|
|
|
|
complete_q->Pop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (exception_.IsCaught()) {
|
|
|
|
|
ClearFetchOp(graph_, &fetch_ops);
|
|
|
|
|
exception_.ReThrow();
|
|
|
|
|
}
|
|
|
|
|
ready_fetch_ops->emplace_back(op);
|
|
|
|
|
}
|
|
|
|
|
num_complete += num_comp;
|
|
|
|
|
}
|
|
|
|
|
// Wait FetchOps.
|
|
|
|
|
ClearFetchOp(graph_, &fetch_ops);
|
|
|
|
|
return fetches;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FastThreadedSSAGraphExecutor::RunOp(
|
|
|
|
|
OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
|
|
|
|
|
size_t *complete) {
|
|
|
|
|
try {
|
|
|
|
|
RunOpSync(op);
|
|
|
|
|
if (LIKELY(!exception_.IsCaught())) {
|
|
|
|
|
if (LIKELY(!strategy_.dry_run_)) {
|
|
|
|
|
op->Run(strategy_.use_cuda_);
|
|
|
|
|
RecordOps(op);
|
|
|
|
|
}
|
|
|
|
|
++(*complete);
|
|
|
|
|
return true;
|
|
|
|
|
} catch (...) {
|
|
|
|
|
exception_.Catch(std::current_exception());
|
|
|
|
|
} else {
|
|
|
|
|
--remaining_;
|
|
|
|
|
complete_q->Push(-1UL);
|
|
|
|
|
return false;
|
|
|
|
@ -194,6 +220,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
|
|
|
|
|
complete_q->Push(complete);
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
|
|
|
|
|
atomic_op_deps_ = prepare_pool_.enqueue([&] {
|
|
|
|
|
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
|
|
|
|
@ -206,6 +233,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Accessor for the SSA graph this executor runs over; returns a reference to
// the executor's graph_ member (no ownership transfer).
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
|
|
|
|
|
|
|
|
|
|
// Record an executed operator into traced_ops_ so it can be replayed later
// by RunTracedOps. Only done in single-threaded mode, and FetchOpHandle
// instances are excluded from the trace.
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
  const bool single_threaded = (strategy_.num_threads_ == 1);
  const bool is_fetch_op = (dynamic_cast<FetchOpHandle *>(op) != nullptr);
  if (single_threaded && !is_fetch_op) {
    traced_ops_.emplace_back(op);
  }
}
|
|
|
|
|
|
|
|
|
|
// Terminal error path shared by the run loops: log the captured exception
// type, clean up the per-Run fetch op handles, then rethrow the recorded
// exception to the caller.
void FastThreadedSSAGraphExecutor::ExecutionFinal(
    std::vector<OpHandleBase *> *fetch_ops) {
  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
  // Remove the temporary fetch ops that were inserted into the graph for
  // this Run() invocation before propagating the error.
  ClearFetchOp(graph_, fetch_ops);
  exception_.ReThrow();
}
|
|
|
|
|
|
|
|
|
|
void FastThreadedSSAGraphExecutor::RunTracedOps(
|
|
|
|
|
const std::vector<OpHandleBase *> &traced_ops) {
|
|
|
|
|
for (auto &op : traced_ops) {
|
|
|
|
|
if (exception_.IsCaught()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
RunOpSync(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Run a single operator on the calling thread. Any exception the op throws
// is captured into exception_ rather than propagated, so callers must check
// exception_.IsCaught() afterwards.
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
  try {
    if (VLOG_IS_ON(10)) {
      VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
    }
    // dry_run_ skips the actual op execution (only the surrounding
    // bookkeeping/logging happens in that mode).
    if (LIKELY(!strategy_.dry_run_)) {
      op->Run(strategy_.use_cuda_);
    }
    VLOG(10) << op << " " << op->Name() << " Done ";
  } catch (...) {
    // Store the exception for the executor's run loop to surface later.
    exception_.Catch(std::current_exception());
  }
}
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|