@@ -51,11 +51,22 @@ class ParallelExecutorPrivate {
     }
   }
 
-  void ResetRuntimeReferenceCount() {
-    for (size_t i = 0; i < rt_ref_cnts_.size(); ++i) {
-      for (auto &pair : rt_ref_cnts_[i]) {
-        rt_cur_ref_cnts_[i][pair.first] = pair.second;
+  std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
+      std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
+
+  inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
+
+  void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
+                                  const std::string &fetched_var_name) {
+    for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
+      for (auto &pair : global_ref_cnts_[i]) {
+        runtime_ref_cnts_[i][pair.first] = pair.second;
       }
+
+      for (auto &fetch_name : fetch_tensors) {
+        runtime_ref_cnts_[i].erase(fetch_name);
+      }
+      runtime_ref_cnts_[i].erase(fetched_var_name);
     }
   }
 
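Note: the new `ResetRuntimeReferenceCount` folds the per-iteration fetch exemption into the reset itself. Runtime counts are restored from the immutable `global_ref_cnts_` snapshot, then every fetched variable is erased so its count can never reach zero and trigger deletion mid-iteration. Below is a minimal standalone sketch of this copy-then-erase pattern; `ResetCounts` and its parameters are illustrative names, not Paddle API.

```cpp
#include <string>
#include <unordered_map>
#include <vector>

// Restore per-variable reference counts from a static snapshot, then drop
// fetched variables so they are exempt from eager deletion this iteration.
void ResetCounts(const std::unordered_map<std::string, size_t> &global,
                 std::unordered_map<std::string, size_t> *runtime,
                 const std::vector<std::string> &fetched) {
  for (const auto &pair : global) {
    (*runtime)[pair.first] = pair.second;
  }
  for (const auto &name : fetched) {
    runtime->erase(name);
  }
}
```
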
@@ -71,14 +82,75 @@ class ParallelExecutorPrivate {
   bool use_cuda_;
   bool use_all_reduce_;
 
-  // rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then
-  // keeps unchanged
-  // Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_
-  std::vector<details::ReferenceCountMap> rt_ref_cnts_;
-  std::vector<details::AtomicReferenceCountMap> rt_cur_ref_cnts_;
-  details::GarbageCollectorList gcs_;
+  // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
+  // then keeps unchanged
+  // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
+  std::vector<details::ReferenceCountMap> global_ref_cnts_;
+  std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
+  details::GarbageCollectorMap gcs_;
 };
 
+std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
+    std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
+  for (size_t i = 0; i < places_.size(); ++i) {
+    auto &place = places_[i];
+    if (gcs_.count(place) > 0) {
+      continue;
+    }
+    GarbageCollector<Tensor> *gc = nullptr;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(place)) {
+      if (IsFastEagerDeletionModeEnabled()) {
+        gc = new UnsafeFastGPUGarbageCollector<Tensor>(
+            boost::get<platform::CUDAPlace>(place), max_memory_size);
+      } else {
+        gc = new StreamGarbageCollector<Tensor>(
+            boost::get<platform::CUDAPlace>(place), max_memory_size);
+      }
+      VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
+    } else if (platform::is_cpu_place(place)) {
+#endif
+      gc = new CPUGarbageCollector<Tensor>(
+          boost::get<platform::CPUPlace>(place), max_memory_size);
+      VLOG(10) << "Created GarbageCollector at " << place;
+#ifdef PADDLE_WITH_CUDA
+    }
+#endif
+
+    if (gc) {
+      gcs_[place] = std::unique_ptr<GarbageCollector<Tensor>>(gc);
+    }
+  }
+
+  if (!gcs_.empty()) {
+    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
+
+    auto ref_cnt_pass =
+        ir::PassRegistry::Instance().Get("reference_count_pass");
+    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
+                              &global_ref_cnts_);
+    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+                              &last_live_ops_of_vars);
+    graph = ref_cnt_pass->Apply(std::move(graph));
+    VLOG(10) << "ReferenceCountPass Applied";
+
+    auto eager_deletion_pass =
+        ir::PassRegistry::Instance().Get("eager_deletion_pass");
+    eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
+                                     &runtime_ref_cnts_);
+    eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
+    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+                                     &last_live_ops_of_vars);
+    eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
+    graph = eager_deletion_pass->Apply(std::move(graph));
+    VLOG(10) << "EagerDeletionPass Applied";
+
+    graph->SetNotOwned(details::kGarbageCollector, &gcs_);
+  }
+
+  return graph;
+}
+
 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
   return member_->local_scopes_;
 }
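Note: `gcs_` changes from a position-indexed `GarbageCollectorList` to a place-keyed `GarbageCollectorMap`, which is what makes `PrepareGCAndRefCnts` idempotent: the `gcs_.count(place) > 0` guard skips any place that already owns a collector. A rough sketch of that idea with simplified stand-in types (`Place`, `Collector`, and `EnsureCollectors` are hypothetical, not Paddle's):

```cpp
#include <map>
#include <memory>
#include <vector>

// Hypothetical stand-ins for platform::Place and GarbageCollector.
struct Place {
  int device_id;
  bool operator<(const Place &rhs) const { return device_id < rhs.device_id; }
};
struct Collector {};

using CollectorMap = std::map<Place, std::unique_ptr<Collector>>;

// Safe to call repeatedly: at most one collector is created per place.
void EnsureCollectors(const std::vector<Place> &places, CollectorMap *gcs) {
  for (const auto &place : places) {
    if (gcs->count(place) > 0) {
      continue;  // this place already owns a collector
    }
    (*gcs)[place] = std::unique_ptr<Collector>(new Collector());
  }
}
```

The passes at the end of the function run only when at least one collector exists, and `graph->SetNotOwned(details::kGarbageCollector, &gcs_)` hands the executor a non-owning view of the same map.
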
@@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor(
 
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
-    size_t place_num = member_->places_.size();
-    for (size_t i = 0; i < place_num; ++i) {
-      auto &place = member_->places_[i];
-#ifdef PADDLE_WITH_CUDA
-      if (platform::is_gpu_place(place)) {
-        if (IsFastEagerDeletionModeEnabled()) {
-          member_->gcs_.emplace_back(new UnsafeFastGPUGarbageCollector<Tensor>(
-              boost::get<platform::CUDAPlace>(place), max_memory_size));
-        } else {
-          member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>(
-              boost::get<platform::CUDAPlace>(place), max_memory_size));
-        }
-        VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-      } else if (platform::is_cpu_place(place)) {
-#endif
-        member_->gcs_.emplace_back(new CPUGarbageCollector<Tensor>(
-            boost::get<platform::CPUPlace>(place), max_memory_size));
-        VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-#ifdef PADDLE_WITH_CUDA
-      }
-#endif
-    }
-  }
-
-  if (!member_->gcs_.empty()) {
-    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
-
-    auto ref_cnt_pass =
-        ir::PassRegistry::Instance().Get("reference_count_pass");
-    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
-                              &(member_->rt_ref_cnts_));
-    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                              &last_live_ops_of_vars);
-    graph = ref_cnt_pass->Apply(std::move(graph));
-    VLOG(10) << "ReferenceCountPass Applied";
-
-    auto eager_deletion_pass =
-        ir::PassRegistry::Instance().Get("eager_deletion_pass");
-    eager_deletion_pass->SetNotOwned(details::kCurReferenceCount,
-                                     &(member_->rt_cur_ref_cnts_));
-    eager_deletion_pass->SetNotOwned(details::kGarbageCollector,
-                                     &(member_->gcs_));
-    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                                     &last_live_ops_of_vars);
-    graph = eager_deletion_pass->Apply(std::move(graph));
-    VLOG(10) << "EagerDeletionPass Applied";
-
-    graph->SetNotOwned(details::kGarbageCollector, &(member_->gcs_));
+    graph = member_->PrepareGCAndRefCnts(std::move(graph),
+                                         static_cast<size_t>(max_memory_size));
   }
 
   // Step 3. Create vars in each scope. Passes may also create new vars.
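Note: the constructor now keeps only the threshold check and delegates all collector construction and pass wiring to `PrepareGCAndRefCnts`. The threshold is signed precisely so that a negative value from `GetEagerDeletionThreshold()` can mean "eager deletion disabled"; the cast to `size_t` happens only after non-negativity is established. A minimal sketch of that guard-then-cast pattern (`MaybeEnableGC` is an illustrative helper, not part of the codebase):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// A signed threshold lets a negative value mean "disabled"; convert to an
// unsigned byte count only after the sign check.
void MaybeEnableGC(int64_t max_memory_size) {
  if (max_memory_size >= 0) {
    size_t threshold = static_cast<size_t>(max_memory_size);
    std::printf("eager deletion enabled, threshold = %zu bytes\n", threshold);
  } else {
    std::printf("eager deletion disabled\n");
  }
}
```
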
@@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices(
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
   platform::RecordBlock b(0);
-  if (!member_->gcs_.empty()) {
-    member_->ResetRuntimeReferenceCount();
-    size_t n = member_->rt_ref_cnts_.size();
-    for (size_t i = 0; i < n; ++i) {
-      for (auto &fetch_name : fetch_tensors) {
-        member_->rt_cur_ref_cnts_[i].erase(fetch_name);
-      }
-      member_->rt_cur_ref_cnts_[i].erase(fetched_var_name);
-    }
+  if (member_->HasGarbageCollectors()) {
+    member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
   }
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
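Note: `Run` now reduces the garbage-collection bookkeeping to a single guarded call: reset the runtime reference counts (exempting everything the caller fetches), execute the graph so that operators decrement counts and dead tensors go to their place's collector, then store the results under `fetched_var_name`. A condensed, self-contained sketch of that control flow, with `Member` and `Executor` as hypothetical simplifications of the types in the diff:

```cpp
#include <string>
#include <vector>

struct Executor {
  // Stand-in for the real graph executor; returns dummy fetch data.
  int Run(const std::vector<std::string> &fetch_tensors) { return 0; }
};

struct Member {
  bool has_gcs = false;
  bool HasGarbageCollectors() const { return has_gcs; }
  void ResetRuntimeReferenceCount(
      const std::vector<std::string> &fetch_tensors,
      const std::string &fetched_var_name) { /* reset + exempt fetched */ }
  Executor executor;
};

int RunOneIteration(Member *member,
                    const std::vector<std::string> &fetch_tensors,
                    const std::string &fetched_var_name) {
  if (member->HasGarbageCollectors()) {
    // Exempt fetched variables before rebuilding runtime counts so a
    // result tensor is never freed before it reaches the fetch list.
    member->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
  }
  // Execute the graph; results are later stored under fetched_var_name.
  return member->executor.Run(fetch_tensors);
}
```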