diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 79b390dde4..5ce92ad826 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); - if (strategy_.num_iteration_per_run_ > 1) { - int read_op_num = 0; - for (auto *node : graphs_[0]->Nodes()) { - if (node->IsOp() && node->Name() == "read") { - read_op_num++; - } - } - if (read_op_num == 0) { - LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model " - "should use pyreader to feed data!"; - } - } - // set the correct size of thread pool to each device. strategy_.num_threads_ = strategy_.num_threads_ < places_.size() ? 1UL @@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run( for (size_t i = 0; i < places_.size(); ++i) { auto call = [this, i, &fetch_tensors]() -> FeedFetchList { try { - for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) { - executors_[i]->Run(fetch_tensors); - } return executors_[i]->Run(fetch_tensors); } catch (...) { exception_holder_.Catch(std::current_exception()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 677a293794..16fa2a6db6 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( places_(places), fetch_ctxs_(places), running_ops_(0), - strategy_(strategy) {} + strategy_(strategy) { + if (strategy_.num_iteration_per_run_ > 1) { + int read_op_num = 0; + for (auto *node : graph_->Nodes()) { + if (node->IsOp() && node->Name() == "read") { + read_op_num++; + } + } + if (read_op_num == 0) { + LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model " + "should use pyreader to feed data!"; + } + } +} -FeedFetchList ThreadedSSAGraphExecutor::Run( +inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( const std::vector &fetch_tensors) { std::unique_ptr event( new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr)); @@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( return fetch_data; } +FeedFetchList ThreadedSSAGraphExecutor::Run( + const std::vector &fetch_tensors) { + for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) { + RunImpl({}); + } + return RunImpl(fetch_tensors); +} + void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector &fetch_tensors, std::vector *fetch_ops, diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 24da56c09e..3809b6e9ae 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ~ThreadedSSAGraphExecutor() final = default; private: + inline FeedFetchList RunImpl(const std::vector &fetch_tensors); void RunOp(const std::shared_ptr> &ready_var_q, details::OpHandleBase *op);