@@ -53,27 +53,40 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
   std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
   CopyOpDeps();
   VLOG(10) << "ThreadedSSAGraphExecutor::Run";
   std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
       new BlockingQueue<VarHandleBase *>);
   auto &pending_ops = op_deps->pending_ops_;
   auto &pending_vars = op_deps->pending_vars_;
   auto &ready_ops = op_deps->ready_ops_;
+  size_t num_ops = op_deps->num_ops_;

-  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
-  // streams from multiple GPUs, it's faster to buffer them and schedule
-  // together since we currently cannot overlap computation and memcpy streams.
-  // Should revisit it if overlapping is available.
-  std::unordered_set<OpHandleBase *> delayed_ops;

   // Step 2. Insert FetchOps
-  std::vector<FetchOpHandle *> fetch_ops;
+  std::vector<OpHandleBase *> fetch_ops;
   std::unordered_set<VarHandleBase *> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());

   InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
                  &pending_ops, &pending_vars, &fetch_data);

+  exception_holder_.Clear();
+  event.reset(nullptr);
+
+  // Step 3. Execution
+  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
+    // If the num_threads is 1, we can record the order of operator's
+    // execution in the first iteration, and in subsequent iterations,
+    // run the recorded operators directly. This strategy could make the
+    // execution faster.
+    VLOG(3) << "Run the traced ops.";
+    RunTracedOps(traced_ops_);
+    RunTracedOps(fetch_ops);
+    if (exception_holder_.IsCaught()) {
+      ExecutionFinal(&fetch_ops);
+    }
+  } else {
+    traced_ops_.clear();
     auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
       for (auto *op : set) {
         RunOp(ready_vars, op);
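The hunk above introduces a trace-and-replay fast path: when the executor runs with a single thread and every operator in the graph has already been traced (traced_ops_.size() == num_ops), it replays the recorded order instead of doing dependency-driven scheduling. What follows is a minimal, self-contained sketch of that idea; the names Op, TraceAndReplayExecutor, RunAndTrace and Replay are illustrative assumptions, not part of the patch.

#include <cstddef>
#include <functional>
#include <vector>

struct Op {
  std::function<void()> run;  // stand-in for OpHandleBase::Run
};

class TraceAndReplayExecutor {
 public:
  // First run: execute ops in whatever order the scheduler produced and
  // remember that order.
  void RunAndTrace(const std::vector<Op *> &scheduled_order) {
    trace_.clear();
    for (Op *op : scheduled_order) {
      op->run();
      trace_.push_back(op);
    }
  }

  // Later runs: if the graph still has the expected number of ops, replay
  // the recorded order and skip dependency tracking entirely.
  bool Replay(std::size_t expected_num_ops) {
    if (trace_.size() != expected_num_ops) {
      return false;  // graph changed, caller falls back to full scheduling
    }
    for (Op *op : trace_) {
      op->run();
    }
    return true;
  }

 private:
  std::vector<Op *> trace_;
};

ThreadedSSAGraphExecutor follows the same shape: RecordOps (added further below) appends each executed operator to traced_ops_ during a single-threaded run, and later iterations replay through RunTracedOps only while the recorded count still matches num_ops.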
@@ -82,9 +95,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
     };
     // Clean run context
     run_op_futures_.clear();
-  exception_holder_.Clear();
-  event.reset(nullptr);
-  // Step 3. Execution
+
     while (!pending_vars.empty()) {
       // 1. Run All Ready ops
       // Keep loop until all vars are ready.
@@ -94,14 +105,11 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
       bool timeout;
       auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
       if (timeout) {
-        if (exception_holder_.IsCaught()) {
-          VLOG(3) << "caught exception " << exception_holder_.Type()
-                  << ", rethrow it";
         for (auto &run_op_future : run_op_futures_) {
           run_op_future.wait();
         }
-        ClearFetchOp(graph_, &fetch_ops);
-        exception_holder_.ReThrow();
+        if (exception_holder_.IsCaught()) {
+          ExecutionFinal(&fetch_ops);
         } else {
           continue;
         }
@@ -121,6 +129,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
       }
     }
     PADDLE_ENFORCE(ready_ops.empty());
+  }
+
   // Wait FetchOps.
   ClearFetchOp(graph_, &fetch_ops);
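The else branch kept by the hunks above is the original dataflow scheduler: each pending operator carries a count of inputs that are not yet ready, ready operators are launched, and every finished operator pushes its output variables into a blocking queue that the loop drains to release further operators. Below is a single-threaded sketch of that counting scheme, with hypothetical Node and Var types standing in for OpHandleBase and VarHandleBase.

#include <cstddef>
#include <deque>
#include <unordered_map>
#include <vector>

struct Var;
struct Node {
  std::vector<Var *> outputs;
  void Run() { /* operator body */ }
};
struct Var {
  std::vector<Node *> pending_ops;  // ops that consume this variable
};

// Mirrors pending_ops / ready_ops / ready_vars in RunImpl, but synchronously.
inline void RunGraph(std::unordered_map<Node *, std::size_t> pending_inputs,
                     std::deque<Node *> ready) {
  std::deque<Var *> ready_vars;
  while (!ready.empty() || !ready_vars.empty()) {
    while (!ready.empty()) {          // 1. run every op with no pending input
      Node *op = ready.front();
      ready.pop_front();
      op->Run();
      for (Var *v : op->outputs) ready_vars.push_back(v);
    }
    if (ready_vars.empty()) break;    // nothing left to release
    Var *v = ready_vars.front();      // 2. take one newly ready variable
    ready_vars.pop_front();
    for (Node *op : v->pending_ops) { // 3. release consumers whose count hits 0
      if (--pending_inputs[op] == 0) ready.push_back(op);
    }
  }
}

In the real executor the ready-variable queue is shared with worker threads (ready_vars->PopAll with a timeout), which is why the timeout and exception checks inside the loop above exist.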
@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(

 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
-    std::vector<FetchOpHandle *> *fetch_ops,
+    std::vector<OpHandleBase *> *fetch_ops,
     std::unordered_set<VarHandleBase *> *fetch_dependencies,
     std::unordered_set<OpHandleBase *> *ready_ops,
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,
@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
       InsertPendingOp(&pending_ops, op);
     }
   }
+  op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
+  PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
+
   for (auto ready_var : ready_vars) {
     pending_vars.erase(ready_var);
     for (auto *op : ready_var->PendingOps()) {
@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
                                    op_deps_->pending_vars_.end());
     op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
                                op_deps_->ready_ops_.end());
+    op_deps->num_ops_ = op_deps_->num_ops_;
     return std::unique_ptr<OpDependentData>(op_deps);
   });
 }
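The two small hunks above thread the new num_ops_ field through the dependency snapshot: PrepareOpDeps computes it on the master copy and CopyOpDeps carries it into the per-run copy, so RunImpl can compare it against traced_ops_.size(). The snapshot itself reaches RunImpl through a future (op_deps_futures_.get() at the top of the diff) and appears to be refilled by CopyOpDeps. A rough sketch of that double-buffering pattern, with a hypothetical Snapshot type standing in for OpDependentData and SnapshotSource as an invented wrapper:

#include <cstddef>
#include <future>
#include <memory>
#include <utility>

// Hypothetical stand-in for OpDependentData.
struct Snapshot {
  std::size_t num_ops = 0;
  // pending_ops_, pending_vars_, ready_ops_ live here in the real struct.
};

class SnapshotSource {
 public:
  explicit SnapshotSource(Snapshot master) : master_(std::move(master)) {
    Refill();
  }

  // Called at the start of a run: hand out the prepared copy and start
  // preparing the next one in the background.
  std::unique_ptr<Snapshot> Take() {
    std::unique_ptr<Snapshot> snapshot = next_.get();
    Refill();
    return snapshot;
  }

 private:
  void Refill() {
    next_ = std::async(std::launch::async, [this] {
      return std::unique_ptr<Snapshot>(new Snapshot(master_));  // deep copy
    });
  }

  Snapshot master_;
  std::future<std::unique_ptr<Snapshot>> next_;
};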
@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
     const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
     details::OpHandleBase *op) {
   auto op_run = [ready_var_q, op, this] {
+    RunOpSync(op);
     try {
-      if (VLOG_IS_ON(10)) {
-        VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
-      }
-      if (LIKELY(!strategy_.dry_run_)) {
-        op->Run(strategy_.use_cuda_);
-      }
-      VLOG(10) << op << " " << op->Name() << " Done ";
       ready_var_q->Extend(op->Outputs());
       VLOG(10) << op << " " << op->Name() << " Signal posted";
     } catch (...) {
       exception_holder_.Catch(std::current_exception());
     }
   };

   if (pool_) {
     run_op_futures_.emplace_back(pool_->enqueue(op_run));
   } else {
     op_run();
   }
+
+  RecordOps(op);
 }
+
+void ThreadedSSAGraphExecutor::RunTracedOps(
+    const std::vector<OpHandleBase *> &traced_ops) {
+  for (auto &op : traced_ops) {
+    if (exception_holder_.IsCaught()) {
+      return;
+    }
+    RunOpSync(op);
+  }
+}
+
+void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
+  try {
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
+    }
+    if (LIKELY(!strategy_.dry_run_)) {
+      op->Run(strategy_.use_cuda_);
+    }
+    VLOG(10) << op << " " << op->Name() << " Done ";
+  } catch (...) {
+    exception_holder_.Catch(std::current_exception());
+  }
+}
+
+void ThreadedSSAGraphExecutor::ExecutionFinal(
+    std::vector<OpHandleBase *> *fetch_ops) {
+  VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
+  ClearFetchOp(graph_, fetch_ops);
+  exception_holder_.ReThrow();
+}
+
+void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
+  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
+    traced_ops_.emplace_back(op);
+  }
+}

 } // namespace details
 } // namespace framework
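RunOp, at the top of the last hunk, shows the executor's worker pattern: the operator body is wrapped in a closure, handed to a thread pool, its future is kept in run_op_futures_, any exception is funnelled into the shared exception_holder_, and completed outputs are pushed into the ready-variable queue. Below is a condensed sketch of that pattern using only the standard library; ExceptionHolder and RunAsync here are simplified stand-ins, not the PaddlePaddle classes.

#include <exception>
#include <future>
#include <mutex>
#include <vector>

// Simplified stand-in for details::ExceptionHolder.
class ExceptionHolder {
 public:
  void Catch(std::exception_ptr e) {
    std::lock_guard<std::mutex> guard(mu_);
    if (eptr_ == nullptr) eptr_ = e;  // keep only the first exception
  }
  bool IsCaught() {
    std::lock_guard<std::mutex> guard(mu_);
    return eptr_ != nullptr;
  }
  void ReThrow() { std::rethrow_exception(eptr_); }

 private:
  std::mutex mu_;
  std::exception_ptr eptr_;
};

// Run a task on a worker thread, remember its future so callers can wait,
// and capture any exception into the shared holder instead of losing it.
template <typename Task>
void RunAsync(Task task, ExceptionHolder *holder,
              std::vector<std::future<void>> *futures) {
  futures->emplace_back(std::async(std::launch::async, [task, holder] {
    try {
      task();  // op->Run(...) plus ready-queue signalling in the real code
    } catch (...) {
      holder->Catch(std::current_exception());
    }
  }));
}

The waiting side matches the timeout branch earlier in the diff: wait on every stored future, then either call ExecutionFinal, which rethrows through exception_holder_, or continue the scheduling loop.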