@@ -64,8 +64,6 @@ ParallelExecutor::ParallelExecutor(
     const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
     size_t num_trainers, size_t trainer_id)
     : member_(new ParallelExecutorPrivate(places)) {
-  is_alive_.test_and_set();
-
   member_->global_scope_ = scope;
   member_->use_cuda_ = exec_strategy.use_cuda_;
   member_->use_all_reduce_ =
@@ -248,15 +246,6 @@ void ParallelExecutor::BCastParamsToDevices(
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  // If ParallelExecutor has been destructed
-  // just return
-  if (!is_alive_.test_and_set()) return;
-
-  // If ParallelExecutor is running
-  if (is_running_.test_and_set()) {
-    PADDLE_THROW("The previous ParallelExecutor::Run() has not stopped");
-  }
-
   platform::RecordBlock b(0);
 #ifdef PADDLE_WITH_CUDA
   if (!gcs_.empty()) {
@@ -270,17 +259,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     }
   }
 #endif
-  try {
-    auto fetch_data = member_->executor_->Run(fetch_tensors);
-    *member_->global_scope_->Var(fetched_var_name)
-        ->GetMutable<FeedFetchList>() = fetch_data;
-    is_running_.clear();
-  } catch (...) {
-    is_running_.clear();
-    if (is_alive_.test_and_set()) {
-      std::rethrow_exception(std::current_exception());
-    }
-  }
+  auto fetch_data = member_->executor_->Run(fetch_tensors);
+  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
+      fetch_data;
 }
 
 void ParallelExecutor::FeedTensorsIntoLocalScopes(
@@ -318,7 +299,6 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
 }
 
 ParallelExecutor::~ParallelExecutor() {
-  is_alive_.clear();
   if (member_->own_local_scope_) {
     for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
       Scope *local_scope = member_->local_scopes_[i];
@@ -328,10 +308,8 @@ ParallelExecutor::~ParallelExecutor() {
     }
   }
 
-  while (is_running_.test_and_set()) {
-    // wait unitl all threads have been stopped
-  }
-
   // member_ must be destructed before gcs_ since the destructor of
   // ReferenceCountOpHandle use raw pointers of gcs_ inside.
   member_.reset();
 }
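
Note: the guard removed by the hunks above is the usual std::atomic_flag pattern (test_and_set() returns the previous state; clear() resets it). The standalone sketch below uses hypothetical names and is not part of the patch; it only illustrates how such flags gate Run() against re-entry and against a concurrent destructor.

#include <atomic>
#include <stdexcept>

// Minimal sketch of an atomic_flag guard like the one being removed.
class GuardedRunner {
 public:
  GuardedRunner() { is_alive_.test_and_set(); }

  ~GuardedRunner() {
    is_alive_.clear();
    // Spin until any in-flight Run() clears is_running_.
    while (is_running_.test_and_set()) {
    }
  }

  void Run() {
    // Destructor already started: bail out quietly.
    if (!is_alive_.test_and_set()) return;
    // test_and_set() returning true means a previous Run() is still active.
    if (is_running_.test_and_set()) {
      throw std::runtime_error("previous Run() has not stopped");
    }
    // ... real work would go here ...
    is_running_.clear();
  }

 private:
  std::atomic_flag is_alive_ = ATOMIC_FLAG_INIT;
  std::atomic_flag is_running_ = ATOMIC_FLAG_INIT;
};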