|
|
|
@ -50,7 +50,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
// together since we currently cannot overlap computation and memcpy streams.
|
|
|
|
|
// Should revisit it if overlapping is available.
|
|
|
|
|
std::unordered_set<OpHandleBase *> delayed_ops;
|
|
|
|
|
std::unordered_set<OpHandleBase *> after_delayed_ops;
|
|
|
|
|
std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
|
|
|
|
|
std::unordered_set<VarHandleBase *> delayed_vars;
|
|
|
|
|
|
|
|
|
|
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
|
|
|
|
@ -119,7 +119,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
auto run_all_ready_ops = [&] {
|
|
|
|
|
for (auto *op : ready_ops) {
|
|
|
|
|
if (op->IsDelayedOp()) {
|
|
|
|
|
if (op->IsMultiDeviceTransfer()) {
|
|
|
|
|
delayed_ops.insert(op);
|
|
|
|
|
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
|
|
|
|
|
ready_vars.Extend(op->outputs_);
|
|
|
|
@ -162,20 +162,22 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
--deps;
|
|
|
|
|
if (deps == 0) {
|
|
|
|
|
if (delayed_vars.find(ready_var) != delayed_vars.end()) {
|
|
|
|
|
after_delayed_ops.insert(op);
|
|
|
|
|
blocked_by_delayed_ops.insert(op);
|
|
|
|
|
} else {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// When there are no other ops to schedule, schedule buffered delayed
|
|
|
|
|
// ops and unblock other ops.
|
|
|
|
|
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
|
|
|
|
|
RunDelayedOps(delayed_ops);
|
|
|
|
|
delayed_ops.clear();
|
|
|
|
|
for (auto *op : after_delayed_ops) {
|
|
|
|
|
for (auto *op : blocked_by_delayed_ops) {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
}
|
|
|
|
|
after_delayed_ops.clear();
|
|
|
|
|
blocked_by_delayed_ops.clear();
|
|
|
|
|
}
|
|
|
|
|
// Keep loop until all vars are ready.
|
|
|
|
|
}
|
|
|
|
|