|
|
|
@ -35,11 +35,17 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
|
|
|
|
|
FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> pending_ops;
|
|
|
|
|
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
|
|
|
|
|
std::unordered_set<VarHandleBase *> pending_vars;
|
|
|
|
|
|
|
|
|
|
BlockingQueue<VarHandleBase *> ready_vars;
|
|
|
|
|
|
|
|
|
|
std::unordered_set<OpHandleBase *> ready_ops;
|
|
|
|
|
|
|
|
|
|
auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
|
|
|
|
|
pending_vars[&var] = var.generated_op_ == nullptr;
|
|
|
|
|
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
|
|
|
|
|
pending_vars.insert(&var);
|
|
|
|
|
if (var.generated_op_ == nullptr) {
|
|
|
|
|
ready_vars.Push(&var);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
|
|
|
|
@ -101,7 +107,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
auto run_all_ready_ops = [&] {
|
|
|
|
|
for (auto *op : ready_ops) {
|
|
|
|
|
RunOp(pending_vars, op);
|
|
|
|
|
RunOp(ready_vars, op);
|
|
|
|
|
}
|
|
|
|
|
ready_ops.clear();
|
|
|
|
|
};
|
|
|
|
@ -118,29 +124,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
run_all_ready_ops();
|
|
|
|
|
|
|
|
|
|
// 2. Find ready variable
|
|
|
|
|
VarHandleBase *ready_var = nullptr;
|
|
|
|
|
for (auto &pair : pending_vars) {
|
|
|
|
|
if (pair.second.load(std::memory_order_acquire)) {
|
|
|
|
|
ready_var = pair.first;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if there is no variable ready
|
|
|
|
|
if (ready_var == nullptr) {
|
|
|
|
|
// FIXME: use a condition variable instead of busy-waiting.
|
|
|
|
|
// if there is an exception, throw it
|
|
|
|
|
if (exception_) {
|
|
|
|
|
throw * exception_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "=============================";
|
|
|
|
|
for (auto &op : pending_ops) {
|
|
|
|
|
VLOG(10) << op.first->DebugString();
|
|
|
|
|
}
|
|
|
|
|
// keep waiting for the ready variables
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
VarHandleBase *ready_var = ready_vars.Pop();
|
|
|
|
|
|
|
|
|
|
// 3. Remove the dependency of ready_var.
|
|
|
|
|
// Find the ready_ops after the ready_var.
|
|
|
|
@ -189,23 +173,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::RunOp(
|
|
|
|
|
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
|
|
|
|
|
details::OpHandleBase *op) {
|
|
|
|
|
std::vector<std::atomic<bool> *> *ready_buffer =
|
|
|
|
|
new std::vector<std::atomic<bool> *>();
|
|
|
|
|
for (auto *var : op->outputs_) {
|
|
|
|
|
ready_buffer->emplace_back(&pending_vars[var]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto op_run = [ready_buffer, op, this] {
|
|
|
|
|
BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
|
|
|
|
|
auto op_run = [&ready_var_q, op, this] {
|
|
|
|
|
try {
|
|
|
|
|
VLOG(10) << op->Name() << " : " << op->DebugString();
|
|
|
|
|
op->Run(use_event_);
|
|
|
|
|
|
|
|
|
|
for (auto *ready : *ready_buffer) {
|
|
|
|
|
ready->store(true, std::memory_order_release);
|
|
|
|
|
for (auto &each : op->outputs_) {
|
|
|
|
|
ready_var_q.Push(each);
|
|
|
|
|
}
|
|
|
|
|
delete ready_buffer;
|
|
|
|
|
} catch (platform::EnforceNotMet ex) {
|
|
|
|
|
exception_.reset(new platform::EnforceNotMet(ex));
|
|
|
|
|
} catch (...) {
|
|
|
|
|