|
|
|
@ -141,7 +141,6 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
std::vector<std::unique_ptr<ir::Graph>> graphs;
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
|
|
|
|
|
VLOG(1) << "kParallelGraph mode!!";
|
|
|
|
|
for (size_t i = 0; i < member_->places_.size(); ++i) {
|
|
|
|
|
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
|
|
|
|
|
main_program, {member_->places_[i]}, loss_var_name, params,
|
|
|
|
@ -178,8 +177,8 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
|
|
|
|
|
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
|
|
|
|
|
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
|
|
|
|
|
graphs[0] = ref_cnt_pass->Apply(std::move(graphs[i]));
|
|
|
|
|
graphs[0]->SetNotOwned("garbage_collector", &gcs_);
|
|
|
|
|
graphs[i] = ref_cnt_pass->Apply(std::move(graphs[i]));
|
|
|
|
|
graphs[i]->SetNotOwned("garbage_collector", &gcs_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -192,6 +191,18 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
|
|
|
|
|
// Step 3. Create vars in each scope. Passes may also create new vars.
|
|
|
|
|
// skip control vars and empty vars
|
|
|
|
|
std::vector<details::VariableInfo> var_infos;
|
|
|
|
|
for (auto &graph : graphs) {
|
|
|
|
|
for (auto &node : graph->Nodes()) {
|
|
|
|
|
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
|
|
|
|
|
var_infos.emplace_back();
|
|
|
|
|
var_infos.back().name_ = node->Var()->Name();
|
|
|
|
|
var_infos.back().type_ = node->Var()->GetType();
|
|
|
|
|
var_infos.back().persistable_ = node->Var()->Persistable();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
/**
|
|
|
|
|
std::vector<std::vector<details::VariableInfo>> var_infos_list;
|
|
|
|
|
for (size_t i = 0; i < graphs.size(); ++i) {
|
|
|
|
|
std::vector<details::VariableInfo> var_infos;
|
|
|
|
@ -203,8 +214,9 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
var_infos.back().persistable_ = node->Var()->Persistable();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
var_infos_list.emplace_back(std::move(var_infos));
|
|
|
|
|
var_infos_list.push_back(std::move(var_infos));
|
|
|
|
|
}
|
|
|
|
|
**/
|
|
|
|
|
|
|
|
|
|
// If loss_var_name is given, there should be exactly one graph.
|
|
|
|
|
if (loss_var_name.size()) {
|
|
|
|
@ -236,7 +248,7 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
|
|
|
|
|
exec_strategy, member_->local_scopes_, std::move(var_infos_list),
|
|
|
|
|
exec_strategy, member_->local_scopes_, std::move(var_infos),
|
|
|
|
|
member_->places_, std::move(member_->executor_)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|