|
|
|
@ -33,13 +33,6 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
|
|
|
|
|
running_ops_(0),
|
|
|
|
|
allow_op_delay_(allow_op_delay) {}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::RunDelayedOps(
|
|
|
|
|
const std::unordered_set<OpHandleBase *> &delayed_ops) {
|
|
|
|
|
for (auto op : delayed_ops) {
|
|
|
|
|
op->Run(use_event_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> pending_ops;
|
|
|
|
@ -51,8 +44,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
// together since we currently cannot overlap computation and memcpy streams.
|
|
|
|
|
// Should revisit it if overlapping is available.
|
|
|
|
|
std::unordered_set<OpHandleBase *> delayed_ops;
|
|
|
|
|
std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
|
|
|
|
|
std::unordered_set<VarHandleBase *> delayed_vars;
|
|
|
|
|
|
|
|
|
|
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
|
|
|
|
|
pending_vars.insert(&var);
|
|
|
|
@ -122,24 +113,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
InsertPendingOp(*op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto run_all_ready_ops = [&] {
|
|
|
|
|
for (auto *op : ready_ops) {
|
|
|
|
|
if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
|
|
|
|
|
delayed_ops.insert(op);
|
|
|
|
|
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
|
|
|
|
|
ready_vars.Extend(op->outputs_);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
|
|
|
|
|
for (auto *op : set) {
|
|
|
|
|
running_ops_++;
|
|
|
|
|
RunOp(&ready_vars, op);
|
|
|
|
|
}
|
|
|
|
|
ready_ops.clear();
|
|
|
|
|
set.clear();
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Step 3. Execution
|
|
|
|
|
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
|
|
|
|
|
while (!pending_vars.empty()) {
|
|
|
|
|
// 1. Run All Ready ops
|
|
|
|
|
run_all_ready_ops();
|
|
|
|
|
// Keep loop until all vars are ready.
|
|
|
|
|
//
|
|
|
|
|
// NOTE: DelayedOps have a lower priority. It will be scheduled after all
|
|
|
|
|
// ready_ops have been performed.
|
|
|
|
|
if (ready_ops.empty() && allow_op_delay_) {
|
|
|
|
|
run_all_ops(delayed_ops);
|
|
|
|
|
} else {
|
|
|
|
|
run_all_ops(ready_ops);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 2. Find ready variable
|
|
|
|
|
bool timeout;
|
|
|
|
@ -160,29 +153,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
auto &deps = pending_ops[op];
|
|
|
|
|
--deps;
|
|
|
|
|
if (deps == 0) {
|
|
|
|
|
if (delayed_vars.find(ready_var) != delayed_vars.end()) {
|
|
|
|
|
blocked_by_delayed_ops.insert(op);
|
|
|
|
|
if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
|
|
|
|
|
delayed_ops.insert(op);
|
|
|
|
|
} else {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// When there are no other ops to schedule, schedule buffered delayed
|
|
|
|
|
// ops and unblock other ops.
|
|
|
|
|
if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
|
|
|
|
|
RunDelayedOps(delayed_ops);
|
|
|
|
|
delayed_ops.clear();
|
|
|
|
|
for (auto *op : blocked_by_delayed_ops) {
|
|
|
|
|
ready_ops.insert(op);
|
|
|
|
|
}
|
|
|
|
|
blocked_by_delayed_ops.clear();
|
|
|
|
|
}
|
|
|
|
|
// Keep loop until all vars are ready.
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(ready_ops.empty());
|
|
|
|
|
PADDLE_ENFORCE(delayed_ops.empty());
|
|
|
|
|
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
|
|
|
|
|
|
|
|
|
|
// Wait FetchOps.
|
|
|
|
|
if (!fetch_ops.empty()) {
|
|
|
|
|