|
|
|
@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
|
|
|
|
|
// Step 2. Insert FetchOps
|
|
|
|
|
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
|
|
|
|
|
std::vector<DummyVarHandle> dummy_vars;
|
|
|
|
|
FeedFetchList fetch_data(fetch_tensors.size());
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
|
|
|
|
@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
|
|
|
|
|
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
|
|
|
|
|
auto &var_name = fetch_tensors[i];
|
|
|
|
|
auto &vars = fetched_vars.at(var_name);
|
|
|
|
|
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
|
|
|
|
|
fetch_ops.emplace_back(op);
|
|
|
|
|
|
|
|
|
|
// FIXME: Use new device context
|
|
|
|
|
for (auto &p : places_) {
|
|
|
|
|
op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
|
|
|
|
|
}
|
|
|
|
@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
for (auto *var : vars) {
|
|
|
|
|
op->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *fetch_dummy = new DummyVarHandle();
|
|
|
|
|
op->AddOutput(fetch_dummy);
|
|
|
|
|
fetch_dependencies.emplace(fetch_dummy);
|
|
|
|
|
InsertPendingVar(*fetch_dummy);
|
|
|
|
|
InsertPendingOp(*op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|