@@ -34,16 +34,16 @@ class SSAGraphExecutor {
   DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
 
  public:
-  explicit SSAGraphExecutor(SSAGraph *graph) : graph_(*graph) {}
+  // Steal graph inside
+  explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
+      : graph_(std::move(graph)) {}
 
   virtual ~SSAGraphExecutor() {}
 
-  virtual void Run(Scope *global_scope,
-                   const std::vector<std::string> &fetch_tensors,
-                   const std::string &fetch_list_name) = 0;
+  virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
 
  protected:
-  SSAGraph &graph_;
+  std::unique_ptr<SSAGraph> graph_;
 };
 
 class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
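
The hunk above makes SSAGraphExecutor the sole owner of its graph: the constructor takes a std::unique_ptr<SSAGraph> by rvalue reference and moves it into graph_, so the caller gives the graph up ("steals the graph inside"). A minimal, self-contained sketch of the same ownership-transfer pattern, using placeholder names rather than the Paddle types:

    // ownership_sketch.cc -- illustrative only; Graph/Executor are stand-ins.
    #include <memory>
    #include <utility>

    struct Graph {};

    class Executor {
     public:
      // Takes the graph by unique_ptr rvalue reference and moves it in,
      // so the executor owns the graph for its whole lifetime.
      explicit Executor(std::unique_ptr<Graph> &&graph)
          : graph_(std::move(graph)) {}

     private:
      std::unique_ptr<Graph> graph_;
    };

    int main() {
      auto g = std::make_unique<Graph>();
      Executor exec(std::move(g));  // g is empty after this; exec owns the Graph
      return 0;
    }
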
@@ -51,16 +51,17 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           SSAGraph *graph)
-      : SSAGraphExecutor(graph),
+                           std::unique_ptr<SSAGraph> &&graph)
+      : SSAGraphExecutor(std::move(graph)),
         pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
         local_scopes_(local_scopes),
         places_(places),
         fetch_ctxs_(places),
         use_event_(use_event) {}
 
-  void Run(Scope *global_scope, const std::vector<std::string> &fetch_tensors,
-           const std::string &fetch_list_name) override {
+  // Run a SSAGraph by a thread pool
+  // Use topological sort algorithm
+  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override {
     std::unordered_map<OpHandleBase *, size_t> pending_ops;
     std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
     std::unordered_set<OpHandleBase *> ready_ops;
@@ -74,18 +75,18 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
     };
 
     // Transform SSAGraph to pending_ops & pending_vars
-    for (auto &var_map : graph_.vars_) {
+    for (auto &var_map : graph_->vars_) {
       for (auto &name_pair : var_map) {
         for (auto &version_pair : name_pair.second) {
           InsertPendingVar(version_pair.second);
         }
       }
     }
-    for (auto &var : graph_.dep_vars_) {
+    for (auto &var : graph_->dep_vars_) {
       InsertPendingVar(*var);
     }
 
-    for (auto &op : graph_.ops_) {
+    for (auto &op : graph_->ops_) {
       if (op->inputs_.empty()) {  // Special case, Op has no input.
         ready_ops.insert(op.get());
       } else {
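
As the new comments say, Run drives the graph with a thread pool in topological order: ops with no unfinished inputs start in ready_ops, and whenever an op finishes, the pending counts of its consumers drop until they become ready too. A toy, self-contained sketch of that scheduling idea (plain integers instead of OpHandleBase/VarHandleBase, and sequential rather than pooled):

    // topo_sketch.cc -- illustrative only, not Paddle code.
    #include <cstdio>
    #include <queue>
    #include <unordered_map>
    #include <vector>

    int main() {
      // op -> ops that consume its output
      std::unordered_map<int, std::vector<int>> consumers = {
          {0, {2}}, {1, {2}}, {2, {3}}, {3, {}}};
      // op -> number of inputs not yet produced (the "pending" count)
      std::unordered_map<int, size_t> pending = {{0, 0}, {1, 0}, {2, 2}, {3, 1}};

      std::queue<int> ready;
      for (const auto &p : pending) {
        if (p.second == 0) ready.push(p.first);  // ops with no inputs run first
      }

      while (!ready.empty()) {
        int op = ready.front();
        ready.pop();
        std::printf("run op %d\n", op);  // stands in for running the real op
        for (int next : consumers[op]) {
          if (--pending[next] == 0) ready.push(next);  // all inputs are ready
        }
      }
      return 0;
    }
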
@@ -101,7 +102,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
     std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
 
     for (auto &fetch_var_name : fetch_tensors) {
-      for (auto &var_map : graph_.vars_) {
+      for (auto &var_map : graph_->vars_) {
         auto it = var_map.find(fetch_var_name);
         if (it != var_map.end()) {
           fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
@@ -182,8 +183,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
       fetch_op.WaitAndMergeCPUTensors();
     }
 
-    *global_scope->Var(fetch_list_name)->GetMutable<FeedFetchList>() =
-        fetch_data;
+    return fetch_data;
   }
 
   ~ThreadedSSAGraphExecutor() {}
@@ -240,8 +240,6 @@ class ParallelExecutorPrivate {
 
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 
-  details::SSAGraph graph_;
-
   std::unique_ptr<SSAGraphExecutor> executor_;
 };
 
@@ -274,10 +272,10 @@ ParallelExecutor::ParallelExecutor(
   details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
                                            params, member_->local_scopes_,
                                            member_->nccl_ctxs_.get());
-  builder.Build(main_program, &member_->graph_);
+  auto graph = builder.Build(main_program);
 
   member_->executor_.reset(new ThreadedSSAGraphExecutor(
-      num_threads, true, member_->local_scopes_, places, &member_->graph_));
+      num_threads, true, member_->local_scopes_, places, std::move(graph)));
 
   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {
@@ -338,8 +336,9 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  member_->executor_->Run(member_->global_scope_, fetch_tensors,
-                          fetched_var_name);
+  auto fetch_data = member_->executor_->Run(fetch_tensors);
+  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
+      fetch_data;
 }
 
 }  // namespace framework
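
Taken together, the last hunks change the data flow end to end: the builder returns the graph, ownership moves into the executor, and Run returns the FeedFetchList instead of writing it into a named scope variable itself; ParallelExecutor::Run is now the place that stores the result under fetched_var_name. A compressed, stand-alone sketch of that call pattern (every type here is a simplified placeholder, not the real Paddle class):

    // run_flow_sketch.cc -- illustrative only; FeedFetchList, Graph, Executor
    // and Scope are stand-ins for the Paddle types.
    #include <map>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    using FeedFetchList = std::vector<float>;

    struct Graph {};
    struct Scope { std::map<std::string, FeedFetchList> vars; };

    class Executor {
     public:
      explicit Executor(std::unique_ptr<Graph> &&g) : graph_(std::move(g)) {}
      // Returns the fetched tensors instead of writing them into a scope.
      FeedFetchList Run(const std::vector<std::string> &fetches) {
        return FeedFetchList(fetches.size(), 0.0f);  // dummy result
      }

     private:
      std::unique_ptr<Graph> graph_;
    };

    int main() {
      auto graph = std::make_unique<Graph>();        // builder.Build(...)
      Executor executor(std::move(graph));           // executor steals the graph

      Scope global_scope;
      auto fetch_data = executor.Run({"loss"});      // executor returns the data
      global_scope.vars["fetch_list"] = fetch_data;  // caller stores it by name
      return 0;
    }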