|
|
|
@ -23,14 +23,15 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
|
|
|
|
|
size_t num_threads, bool use_event,
|
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
std::unique_ptr<SSAGraph> &&graph)
|
|
|
|
|
std::unique_ptr<SSAGraph> &&graph, bool allow_op_delay)
|
|
|
|
|
: SSAGraphExecutor(std::move(graph)),
|
|
|
|
|
pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
places_(places),
|
|
|
|
|
fetch_ctxs_(places),
|
|
|
|
|
use_event_(use_event),
|
|
|
|
|
running_ops_(0) {}
|
|
|
|
|
running_ops_(0),
|
|
|
|
|
allow_op_delay_(allow_op_delay) {}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::RunDelayedOps(
|
|
|
|
|
const std::unordered_set<OpHandleBase *> &delayed_ops) {
|
|
|
|
@ -119,7 +120,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
auto run_all_ready_ops = [&] {
|
|
|
|
|
for (auto *op : ready_ops) {
|
|
|
|
|
if (op->IsMultiDeviceTransfer()) {
|
|
|
|
|
if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
|
|
|
|
|
delayed_ops.insert(op);
|
|
|
|
|
delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
|
|
|
|
|
ready_vars.Extend(op->outputs_);
|
|
|
|
@ -138,7 +139,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Step 3. Execution
|
|
|
|
|
while (!pending_vars.empty()) {
|
|
|
|
|
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
|
|
|
|
|
// 1. Run All Ready ops
|
|
|
|
|
run_all_ready_ops();
|
|
|
|
|
|
|
|
|
@ -181,6 +182,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
// Keep loop until all vars are ready.
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(ready_ops.empty());
|
|
|
|
|
PADDLE_ENFORCE(delayed_ops.empty());
|
|
|
|
|
PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
|
|
|
|
|
++computation_count_;
|
|
|
|
|
|
|
|
|
|
auto sync_computation = [&] {
|
|
|
|
|