@@ -34,32 +34,63 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
         strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
   }
+  VLOG(1) << "pool size: " << places_.size();
 }
 
 FeedFetchList ParallelSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
-  std::vector<std::future<void>> run_futures;
-  FeedFetchList fetch_data;
+  std::vector<std::future<FeedFetchList>> run_futures;
+
+  std::vector<FeedFetchList> fetch_datas;
+  FeedFetchList ret;
+  fetch_datas.reserve(places_.size());
+  ret.reserve(fetch_tensors.size());
+  exception_holder_.Clear();
 
   for (size_t i = 0; i < places_.size(); ++i) {
-    auto call = [this, i] {
-      // FIXME(Yancey1989): need to fix fetch data failed.
-      std::vector<std::string> empty;
-      executors_[i]->Run(empty);
+    auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
+      return executors_[i]->Run(fetch_tensors);
     };
+
     if (pool_) {
       run_futures.emplace_back(pool_->enqueue(std::move(call)));
     } else {
-      call();
+      try {
+        fetch_datas.emplace_back(std::move(call()));
+      } catch (...) {
+        exception_holder_.Catch(std::current_exception());
+        break;
+      }
     }
   }
 
   if (pool_) {
     for (auto &f : run_futures) {
-      f.wait();
+      if (exception_holder_.IsCaught()) {
+        f.wait();
+      } else {
+        try {
+          fetch_datas.emplace_back(std::move(f.get()));
+        } catch (...) {
+          exception_holder_.Catch(std::current_exception());
+        }
+      }
     }
   }
+  if (exception_holder_.IsCaught()) {
+    exception_holder_.ReThrow();
+  }
+
+  for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
+    std::vector<const LoDTensor *> lodtensor_ptrs;
+    lodtensor_ptrs.reserve(local_scopes_.size());
+    for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
+      lodtensor_ptrs.push_back(&fetch_datas.at(scope_idx).at(fetch_idx));
+    }
+    ret.emplace_back();
+    ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
+  }
-  return fetch_data;
+  return ret;
 }
 
 }  // namespace details
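
Read as a whole, the new Run() is a fan-out/fan-in pattern: one task per place returns that device's fetch results, exceptions are caught into the ExceptionHolder instead of escaping a worker, every future is still drained before rethrowing, and the per-device FeedFetchLists are finally merged per fetched variable with MergeLoDTensor. The snippet below is a minimal standalone sketch of that control flow using only the standard library; RunOneDevice, FetchList, std::async, and the concatenation-style merge are stand-ins for illustration, not the Paddle APIs used in this hunk (ThreadedSSAGraphExecutor::Run, ExceptionHolder, ThreadPool::enqueue, MergeLoDTensor).

// Standalone sketch of the fan-out/fan-in flow in the new Run(); not Paddle code.
#include <exception>
#include <future>
#include <iostream>
#include <vector>

using FetchList = std::vector<int>;  // stand-in for FeedFetchList

FetchList RunOneDevice(size_t device_id) {
  // Pretend each device fetched one value for the requested variable.
  return FetchList{static_cast<int>(device_id) * 10};
}

int main() {
  const size_t num_devices = 4;
  std::vector<std::future<FetchList>> run_futures;
  std::vector<FetchList> per_device;
  std::exception_ptr first_error;  // plays the role of ExceptionHolder

  // Fan out: launch one asynchronous task per device (the hunk uses pool_->enqueue).
  for (size_t i = 0; i < num_devices; ++i) {
    run_futures.emplace_back(std::async(std::launch::async, RunOneDevice, i));
  }

  // Fan in: drain every future; remember the first exception instead of
  // rethrowing immediately, so no worker is abandoned mid-flight.
  for (auto &f : run_futures) {
    try {
      per_device.emplace_back(f.get());
    } catch (...) {
      if (!first_error) first_error = std::current_exception();
    }
  }
  if (first_error) std::rethrow_exception(first_error);

  // Merge step: the hunk merges per variable via MergeLoDTensor; here we
  // simply concatenate the per-device results into one list.
  FetchList merged;
  for (const auto &d : per_device) {
    merged.insert(merged.end(), d.begin(), d.end());
  }
  for (int v : merged) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}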