@@ -23,22 +23,36 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
     size_t num_threads, bool use_event,
     const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::unique_ptr<SSAGraph> &&graph)
+    std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
     : SSAGraphExecutor(std::move(graph)),
       pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
       local_scopes_(local_scopes),
       places_(places),
       fetch_ctxs_(places),
-      use_event_(use_event) {}
+      use_event_(use_event),
+      running_ops_(0),
+      allow_op_delay_(allow_op_delay) {}
+
+void ThreadedSSAGraphExecutor::RunDelayedOps(
+    const std::unordered_set<OpHandleBase *> &delayed_ops) {
+  for (auto op : delayed_ops) {
+    op->Run(use_event_);
+  }
+}
 
 FeedFetchList ThreadedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   std::unordered_map<OpHandleBase *, size_t> pending_ops;
   std::unordered_set<VarHandleBase *> pending_vars;
-
   BlockingQueue<VarHandleBase *> ready_vars;
-
   std::unordered_set<OpHandleBase *> ready_ops;
+  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
+  // streams from multiple GPUs, it's faster to buffer them and schedule
+  // together since we currently cannot overlap computation and memcpy streams.
+  // Should revisit it if overlapping is available.
+  std::unordered_set<OpHandleBase *> delayed_ops;
+  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
+  std::unordered_set<VarHandleBase *> delayed_vars;
 
   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
     pending_vars.insert(&var);
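The initializer list now seeds running_ops_ and allow_op_delay_, so the executor's header presumably declares matching members (not shown in this diff). running_ops_ is incremented on the scheduling thread before an op is handed to the pool (see the run_all_ready_ops hunk below) and decremented inside the worker lambda in RunOp, so it has to be a thread-safe counter. A minimal standalone sketch of that pattern, assuming a std::atomic<int> member:

#include <atomic>
#include <thread>
#include <vector>

int main() {
  // Stand-in for the executor's running_ops_ member: incremented when an op
  // is handed to the pool, decremented by the worker when the op finishes.
  std::atomic<int> running_ops{0};
  std::vector<std::thread> workers;

  for (int i = 0; i < 4; ++i) {
    running_ops++;  // scheduler side, before dispatch
    workers.emplace_back([&running_ops] {
      running_ops--;  // worker side, after the op has run
    });
  }
  for (auto &t : workers) t.join();
  return running_ops.load();  // 0 once every dispatched op has completed
}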
@@ -106,7 +120,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 
   auto run_all_ready_ops = [&] {
     for (auto *op : ready_ops) {
-      RunOp(ready_vars, op);
+      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+        delayed_ops.insert(op);
+        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
+        ready_vars.Extend(op->outputs_);
+        continue;
+      }
+      running_ops_++;
+      RunOp(&ready_vars, op);
     }
     ready_ops.clear();
   };
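ready_vars.Extend(op->outputs_) is what lets a delayed op's consumers keep being counted down even though the op itself has not run yet: the outputs are pushed into the ready-variable queue immediately while the op is parked in delayed_ops. The details::BlockingQueue class is not part of this diff; a minimal sketch with the Extend/PopAll surface the loop relies on, assuming a mutex-plus-condition-variable implementation, looks roughly like this:

#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <vector>

// Minimal sketch of a blocking queue with the Extend/PopAll surface the
// executor uses; the real details::BlockingQueue may differ.
template <typename T>
class BlockingQueue {
 public:
  void Extend(const std::vector<T> &items) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      q_.insert(q_.end(), items.begin(), items.end());
    }
    cv_.notify_all();
  }

  // Waits up to `ms` milliseconds for at least one element; sets *timeout
  // when nothing arrived and returns whatever is currently buffered.
  std::vector<T> PopAll(int ms, bool *timeout) {
    std::unique_lock<std::mutex> lock(mu_);
    *timeout = !cv_.wait_for(lock, std::chrono::milliseconds(ms),
                             [this] { return !q_.empty(); });
    std::vector<T> out(q_.begin(), q_.end());
    q_.clear();
    return out;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> q_;
};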
@@ -118,13 +139,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   }
 
   // Step 3. Execution
-  while (!pending_vars.empty()) {
+  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
     // 1. Run All Ready ops
     run_all_ready_ops();
 
     // 2. Find ready variable
     bool timeout;
-    auto cur_ready_vars = ready_vars.PopAll(1000, &timeout);
+    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
 
     if (timeout) {
       if (exception_) {
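Shrinking the PopAll timeout from 1000 ms to 1 ms matters because the delayed-op flush in the next hunk can only be evaluated between PopAll calls: with ops buffered and nothing else running, a one-second block would stall the step. The widened while condition covers the new case where every variable is nominally ready but delayed or blocked ops still need to be scheduled. Using the queue sketch above (itself an assumption about the real class), the timeout contract is:

#include <iostream>

int main() {
  BlockingQueue<int> ready_vars;
  bool timeout = false;

  auto got = ready_vars.PopAll(1, &timeout);           // nothing queued yet
  std::cout << timeout << " " << got.size() << "\n";   // "1 0": woke after ~1 ms

  ready_vars.Extend({7, 8, 9});
  got = ready_vars.PopAll(1, &timeout);                // data already available
  std::cout << timeout << " " << got.size() << "\n";   // "0 3"
  return 0;
}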
@@ -141,13 +162,29 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         auto &deps = pending_ops[op];
         --deps;
         if (deps == 0) {
-          ready_ops.insert(op);
+          if (delayed_vars.find(ready_var) != delayed_vars.end()) {
+            blocked_by_delayed_ops.insert(op);
+          } else {
+            ready_ops.insert(op);
+          }
         }
       }
     }
+    // When there are no other ops to schedule, schedule buffered delayed
+    // ops and unblock other ops.
+    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
+      RunDelayedOps(delayed_ops);
+      delayed_ops.clear();
+      for (auto *op : blocked_by_delayed_ops) {
+        ready_ops.insert(op);
+      }
+      blocked_by_delayed_ops.clear();
+    }
     // Keep loop until all vars are ready.
   }
-
+  PADDLE_ENFORCE(ready_ops.empty());
+  PADDLE_ENFORCE(delayed_ops.empty());
+  PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
   ++computation_count_;
 
   auto sync_computation = [&] {
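Taken together with the run_all_ready_ops change, the bookkeeping works like this: a delayed op's outputs are reported ready up front, so a consumer's dependency count still reaches zero, but because those variables sit in delayed_vars the consumer is parked in blocked_by_delayed_ops instead of ready_ops; only when nothing else is ready or running are the buffered ops executed via RunDelayedOps and the parked consumers released. A toy walk-through of that flow with plain strings (the op and variable names are illustrative, not taken from a real graph):

#include <iostream>
#include <set>
#include <string>

int main() {
  std::set<std::string> delayed_ops, delayed_vars, blocked, ready_ops;
  int sgd_deps = 1;  // "sgd" waits on one input, "grad"

  // run_all_ready_ops: "nccl" is a multi-device transfer, so it is buffered
  // and its output is reported ready without actually running it.
  delayed_ops.insert("nccl");
  delayed_vars.insert("grad");

  // Dependency counting sees "grad" become ready ...
  if (--sgd_deps == 0) {
    // ... but "grad" comes from a delayed op, so "sgd" is parked.
    if (delayed_vars.count("grad")) blocked.insert("sgd");
    else ready_ops.insert("sgd");
  }

  // Nothing else ready or running: flush the buffered ops and unblock.
  if (ready_ops.empty() && !delayed_ops.empty()) {
    delayed_ops.clear();  // RunDelayedOps would execute "nccl" here
    ready_ops.insert(blocked.begin(), blocked.end());
    blocked.clear();
  }

  std::cout << ready_ops.count("sgd") << "\n";  // prints 1: "sgd" can now run
}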
@@ -182,12 +219,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 }
 
 void ThreadedSSAGraphExecutor::RunOp(
-    BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
-  auto op_run = [&ready_var_q, op, this] {
+    BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
+  auto op_run = [ready_var_q, op, this] {
     try {
       VLOG(10) << op->Name() << " : " << op->DebugString();
       op->Run(use_event_);
-      ready_var_q.Extend(op->outputs_);
+      running_ops_--;
+      ready_var_q->Extend(op->outputs_);
     } catch (platform::EnforceNotMet ex) {
       exception_.reset(new platform::EnforceNotMet(ex));
     } catch (...) {
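The RunOp change is a lifetime fix as much as a signature change: op_run executes later on the thread pool, and the old capture [&ready_var_q] bound the closure to RunOp's reference parameter, i.e. to RunOp's own stack frame, which is gone by the time the worker runs. Capturing a pointer by value copies just the address into the closure, and Run() keeps the underlying queue alive until the scheduling loop finishes, so the lifetime requirement is explicit. A minimal, self-contained illustration of the safer pattern (generic code, not the executor itself):

#include <functional>
#include <iostream>
#include <vector>

// Capturing the pointer by value stores only the address in the closure; the
// caller just has to keep the pointed-to queue alive until the task runs,
// which the scheduling loop guarantees by blocking until all ops finish.
std::function<void()> MakeTask(std::vector<int> *ready_var_q) {
  return [ready_var_q] { ready_var_q->push_back(1); };
}

int main() {
  std::vector<int> ready_vars;
  auto task = MakeTask(&ready_vars);       // analogous to RunOp(&ready_vars, op)
  task();                                  // later, on a worker thread in the real code
  std::cout << ready_vars.size() << "\n";  // prints 1
}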