|
|
|
@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
|
|
|
|
|
places_(places),
|
|
|
|
|
fetch_ctxs_(places),
|
|
|
|
|
running_ops_(0),
|
|
|
|
|
strategy_(strategy) {}
|
|
|
|
|
strategy_(strategy) {
|
|
|
|
|
if (strategy_.num_iteration_per_run_ > 1) {
|
|
|
|
|
int read_op_num = 0;
|
|
|
|
|
for (auto *node : graph_->Nodes()) {
|
|
|
|
|
if (node->IsOp() && node->Name() == "read") {
|
|
|
|
|
read_op_num++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (read_op_num == 0) {
|
|
|
|
|
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
|
|
|
|
|
"should use pyreader to feed data!";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
std::unique_ptr<platform::RecordEvent> event(
|
|
|
|
|
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
|
|
|
|
@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
return fetch_data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FeedFetchList ThreadedSSAGraphExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
|
|
|
|
|
RunImpl({});
|
|
|
|
|
}
|
|
|
|
|
return RunImpl(fetch_tensors);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
std::vector<FetchOpHandle *> *fetch_ops,
|
|
|
|
|