|
|
|
@ -14,8 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/fetch_op_handle.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
// Should revisit it if overlapping is available.
|
|
|
|
|
std::unordered_set<OpHandleBase *> delayed_ops;
|
|
|
|
|
|
|
|
|
|
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
|
|
|
|
|
pending_vars.insert(&var);
|
|
|
|
|
if (var.generated_op_ == nullptr) {
|
|
|
|
|
ready_vars.Push(&var);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
|
|
|
|
|
pending_ops.insert({&op_instance, op_instance.Inputs().size()});
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Transform SSAGraph to pending_ops & pending_vars
|
|
|
|
|
for (auto &var_map : graph_->vars_) {
|
|
|
|
|
for (auto &name_pair : var_map) {
|
|
|
|
|
for (auto &version_pair : name_pair.second) {
|
|
|
|
|
InsertPendingVar(*version_pair);
|
|
|
|
|
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto &var : graph_->dep_vars_) {
|
|
|
|
|
InsertPendingVar(*var);
|
|
|
|
|
InsertPendingVar(&pending_vars, &ready_vars, var.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &op : graph_->ops_) {
|
|
|
|
|
if (op->Inputs().empty()) { // Special case, Op has no input.
|
|
|
|
|
ready_ops.insert(op.get());
|
|
|
|
|
} else {
|
|
|
|
|
InsertPendingOp(*op);
|
|
|
|
|
InsertPendingOp(&pending_ops, op.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Step 2. Insert FetchOps
|
|
|
|
|
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
|
|
|
|
|
FeedFetchList fetch_data(fetch_tensors.size());
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
|
|
|
|
|
|
|
|
|
|
for (auto &fetch_var_name : fetch_tensors) {
|
|
|
|
|
for (auto &var_map : graph_->vars_) {
|
|
|
|
|
auto it = var_map.find(fetch_var_name);
|
|
|
|
|
if (it != var_map.end()) {
|
|
|
|
|
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
|
|
|
|
|
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
|
|
|
|
|
auto &var_name = fetch_tensors[i];
|
|
|
|
|
auto &vars = fetched_vars.at(var_name);
|
|
|
|
|
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
|
|
|
|
|
fetch_ops.emplace_back(op);
|
|
|
|
|
|
|
|
|
|
for (auto &p : places_) {
|
|
|
|
|
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto *var : vars) {
|
|
|
|
|
op->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
FeedFetchList fetch_data(fetch_tensors.size());
|
|
|
|
|
|
|
|
|
|
auto *fetch_dummy = new DummyVarHandle();
|
|
|
|
|
op->AddOutput(fetch_dummy);
|
|
|
|
|
fetch_dependencies.emplace(fetch_dummy);
|
|
|
|
|
InsertPendingVar(*fetch_dummy);
|
|
|
|
|
InsertPendingOp(*op);
|
|
|
|
|
}
|
|
|
|
|
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
|
|
|
|
|
&pending_vars, &ready_vars, &fetch_data);
|
|
|
|
|
|
|
|
|
|
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
|
|
|
|
|
for (auto *op : set) {
|
|
|
|
@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
return fetch_data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
|
|
|
|
|
std::unordered_set<VarHandleBase *> *pending_vars,
|
|
|
|
|
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
|
|
|
|
|
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
|
|
|
|
|
|
|
|
|
|
for (auto &fetch_var_name : fetch_tensors) {
|
|
|
|
|
for (auto &var_map : graph_->vars_) {
|
|
|
|
|
auto it = var_map.find(fetch_var_name);
|
|
|
|
|
if (it != var_map.end()) {
|
|
|
|
|
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
|
|
|
|
|
auto &var_name = fetch_tensors[i];
|
|
|
|
|
auto &vars = fetched_vars.at(var_name);
|
|
|
|
|
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
|
|
|
|
|
fetch_ops->emplace_back(op);
|
|
|
|
|
|
|
|
|
|
for (auto &p : places_) {
|
|
|
|
|
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto *var : vars) {
|
|
|
|
|
op->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *fetch_dummy = new DummyVarHandle();
|
|
|
|
|
op->AddOutput(fetch_dummy);
|
|
|
|
|
fetch_dependencies->emplace(fetch_dummy);
|
|
|
|
|
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
|
|
|
|
|
this->InsertPendingOp(pending_ops, op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::InsertPendingOp(
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
|
|
|
|
|
OpHandleBase *op_instance) const {
|
|
|
|
|
pending_ops->insert({op_instance, op_instance->Inputs().size()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::InsertPendingVar(
|
|
|
|
|
std::unordered_set<VarHandleBase *> *pending_vars,
|
|
|
|
|
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
|
|
|
|
|
pending_vars->insert(var);
|
|
|
|
|
if (var->generated_op_ == nullptr) {
|
|
|
|
|
ready_vars->Push(var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void ThreadedSSAGraphExecutor::RunOp(
|
|
|
|
|
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
|
|
|
|
|
auto op_run = [ready_var_q, op, this] {
|
|
|
|
|