|
|
@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
|
|
|
|
VLOG(3) << "build AsyncSSAGraphExecutor";
|
|
|
|
VLOG(3) << "build AsyncSSAGraphExecutor";
|
|
|
|
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
|
|
|
|
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
|
|
|
|
|
|
|
|
|
|
|
|
if (strategy_.num_iteration_per_run_ > 1) {
|
|
|
|
|
|
|
|
int read_op_num = 0;
|
|
|
|
|
|
|
|
for (auto *node : graphs_[0]->Nodes()) {
|
|
|
|
|
|
|
|
if (node->IsOp() && node->Name() == "read") {
|
|
|
|
|
|
|
|
read_op_num++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (read_op_num == 0) {
|
|
|
|
|
|
|
|
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
|
|
|
|
|
|
|
|
"should use pyreader to feed data!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// set the correct size of thread pool to each device.
|
|
|
|
// set the correct size of thread pool to each device.
|
|
|
|
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
|
|
|
|
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
|
|
|
|
? 1UL
|
|
|
|
? 1UL
|
|
|
@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
|
|
|
|
auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
|
|
|
|
try {
|
|
|
|
try {
|
|
|
|
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
|
|
|
|
|
|
|
|
executors_[i]->Run(fetch_tensors);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return executors_[i]->Run(fetch_tensors);
|
|
|
|
return executors_[i]->Run(fetch_tensors);
|
|
|
|
} catch (...) {
|
|
|
|
} catch (...) {
|
|
|
|
exception_holder_.Catch(std::current_exception());
|
|
|
|
exception_holder_.Catch(std::current_exception());
|
|
|
|