|
|
|
@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
// Step 2. Insert FetchOps
|
|
|
|
|
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
|
|
|
|
|
std::vector<std::unique_ptr<ir::Node>> tmp_nodes;
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
|
|
|
|
|
FeedFetchList fetch_data(fetch_tensors.size());
|
|
|
|
|
|
|
|
|
|
InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies,
|
|
|
|
|
&pending_ops, &pending_vars, &ready_vars, &fetch_data);
|
|
|
|
|
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
|
|
|
|
|
&pending_vars, &ready_vars, &fetch_data);
|
|
|
|
|
|
|
|
|
|
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
|
|
|
|
|
for (auto *op : set) {
|
|
|
|
@ -136,9 +135,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
PADDLE_ENFORCE(ready_ops.empty());
|
|
|
|
|
|
|
|
|
|
// Wait FetchOps.
|
|
|
|
|
if (!fetch_ops.empty()) {
|
|
|
|
|
fetch_ops.clear();
|
|
|
|
|
}
|
|
|
|
|
ClearFetchOp(graph_.get(), &fetch_ops);
|
|
|
|
|
|
|
|
|
|
return fetch_data;
|
|
|
|
|
}
|
|
|
|
@ -146,7 +143,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
void ThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
|
|
|
|
|
std::vector<std::unique_ptr<ir::Node>> *temp_nodes,
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
|
|
|
|
|
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
|
|
|
|
|
std::unordered_set<VarHandleBase *> *pending_vars,
|
|
|
|
@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
|
|
|
|
|
auto &vars = fetched_var_it->second;
|
|
|
|
|
|
|
|
|
|
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
|
|
|
|
|
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
|
|
|
|
|
&local_scopes_);
|
|
|
|
|
ir::Node *fetch_node =
|
|
|
|
|
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
|
|
|
|
|
auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_);
|
|
|
|
|
fetch_ops->emplace_back(op);
|
|
|
|
|
|
|
|
|
|
for (auto &p : places_) {
|
|
|
|
@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
op->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
|
|
|
|
|
auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
|
|
|
|
|
ir::Node *fetch_var =
|
|
|
|
|
graph_->CreateEmptyNode("fetch", ir::Node::Type::kVariable);
|
|
|
|
|
auto *fetch_dummy = new DummyVarHandle(fetch_var);
|
|
|
|
|
op->AddOutput(fetch_dummy);
|
|
|
|
|
fetch_dependencies->emplace(fetch_dummy);
|
|
|
|
|
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
|
|
|
|
|