|
|
|
|
@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
places_(places),
|
|
|
|
|
fetch_ctxs_(places),
|
|
|
|
|
use_event_(use_event) {}
|
|
|
|
|
use_event_(use_event),
|
|
|
|
|
running_ops_(0) {}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::RunDelayedOps(
|
|
|
|
|
const std::unordered_set<OpHandleBase *> &delayed_ops) {
|
|
|
|
|
for (auto op : delayed_ops) {
|
|
|
|
|
op->Run(use_event_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> pending_ops;
|
|
|
|
|
std::unordered_set<VarHandleBase *> pending_vars;
|
|
|
|
|
|
|
|
|
|
BlockingQueue<VarHandleBase *> ready_vars;
|
|
|
|
|
|
|
|
|
|
std::unordered_set<OpHandleBase *> ready_ops;
|
|
|
|
|
|
|
|
|
|
std::unordered_set<OpHandleBase *> delayed_ops;
|
|
|
|
|
std::unordered_set<OpHandleBase *> after_delayed_ops;
|
|
|
|
|
std::unordered_set<VarHandleBase *> delayed_vars;
|
|
|
|
|
|
|
|
|
|
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
|
|
|
|
|
pending_vars.insert(&var);
|
|
|
|
|
if (var.generated_op_ == nullptr) {
|
|
|
|
|
@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
auto run_all_ready_ops = [&] {
|
|
|
|
|
for (auto *op : ready_ops) {
|
|
|
|
|
RunOp(ready_vars, op);
|
|
|
|
|
if (op->IsDelayedOp()) {
|
|
|
|
|
delayed_ops.insert(op);
|
|
|
|
|
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
|
|
|
|
|
ready_vars.Extend(op->outputs_);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
running_ops_++;
|
|
|
|
|
RunOp(&ready_vars, op);
|
|
|
|
|
}
|
|
|
|
|
ready_ops.clear();
|
|
|
|
|
};
|
|
|
|
|
@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
// 2. Find ready variable
|
|
|
|
|
bool timeout;
|
|
|
|
|
auto cur_ready_vars = ready_vars.PopAll(1000, &timeout);
|
|
|
|
|
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
|
|
|
|
|
|
|
|
|
|
if (timeout) {
|
|
|
|
|
if (exception_) {
|
|
|
|
|
@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
auto &deps = pending_ops[op];
|
|
|
|
|
--deps;
|
|
|
|
|
if (deps == 0) {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
if (delayed_vars.find(ready_var) != delayed_vars.end()) {
|
|
|
|
|
after_delayed_ops.insert(op);
|
|
|
|
|
} else {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
|
|
|
|
|
RunDelayedOps(delayed_ops);
|
|
|
|
|
delayed_ops.clear();
|
|
|
|
|
for (auto *op : after_delayed_ops) {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
}
|
|
|
|
|
after_delayed_ops.clear();
|
|
|
|
|
}
|
|
|
|
|
// Keep loop until all vars are ready.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
++computation_count_;
|
|
|
|
|
|
|
|
|
|
auto sync_computation = [&] {
|
|
|
|
|
@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::RunOp(
|
|
|
|
|
BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
|
|
|
|
|
auto op_run = [&ready_var_q, op, this] {
|
|
|
|
|
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
|
|
|
|
|
auto op_run = [ready_var_q, op, this] {
|
|
|
|
|
try {
|
|
|
|
|
VLOG(10) << op->Name() << " : " << op->DebugString();
|
|
|
|
|
op->Run(use_event_);
|
|
|
|
|
ready_var_q.Extend(op->outputs_);
|
|
|
|
|
running_ops_--;
|
|
|
|
|
ready_var_q->Extend(op->outputs_);
|
|
|
|
|
} catch (platform::EnforceNotMet ex) {
|
|
|
|
|
exception_.reset(new platform::EnforceNotMet(ex));
|
|
|
|
|
} catch (...) {
|
|
|
|
|
|