@@ -21,6 +21,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/ir/graph.h"
 
+#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
@@ -282,10 +283,19 @@ ParallelExecutor::ParallelExecutor(
     graphs.push_back(std::move(graph));
   }
 #else
-  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-      main_program, member_->places_, loss_var_name, member_->local_scopes_,
-      member_->nranks_, member_->use_cuda_);
-  graphs.push_back(std::move(graph));
+  if (build_strategy.async_mode_) {
+    for (size_t i = 0; i < member_->places_.size(); ++i) {
+      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
+          main_program, {member_->places_[i]}, loss_var_name,
+          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_);
+      graphs.push_back(std::move(graph));
+    }
+  } else {
+    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
+        main_program, member_->places_, loss_var_name, member_->local_scopes_,
+        member_->nranks_, member_->use_cuda_);
+    graphs.push_back(std::move(graph));
+  }
 #endif
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
@@ -323,23 +333,31 @@ ParallelExecutor::ParallelExecutor(
              "please don't pass loss_var_name.";
     }
   }
-
-  if (build_strategy.enable_parallel_graph_) {
+  if (build_strategy.async_mode_) {
+    VLOG(3) << "use AsyncSSAGraphExecutor";
+    member_->executor_.reset(new details::AsyncSSAGraphExecutor(
+        exec_strategy, member_->local_scopes_, member_->places_,
+        std::move(graphs)));
+  } else if (build_strategy.enable_parallel_graph_) {
+    VLOG(3) << "use ParallelSSAGraphExecutor";
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
         std::move(graphs)));
   } else {
     if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
+      VLOG(3) << "use ThreadedSSAGraphExecutor";
       member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_,
           std::move(graphs[0])));
     } else {
+      VLOG(3) << "use FastThreadedSSAGraphExecutor";
       member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_,
           std::move(graphs[0])));
     }
   }
 
+  VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, std::move(var_infos),
       member_->places_, std::move(member_->executor_)));
@@ -401,14 +419,22 @@ void ParallelExecutor::BCastParamsToDevices(
         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
 
-        // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
-        if (member_->use_all_reduce_ || member_->use_cuda_ ||
-            var == "@LR_DECAY_COUNTER@") {
+        auto copy_memory = [&] {
           t->Resize(dims);
           t->mutable_data(cpu, main_tensor.type());
           paddle::framework::TensorCopy(main_tensor, cpu, t);
+        };
+
+        auto share_memory = [&] { t->ShareDataWith(main_tensor); };
+
+        // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
+        if (member_->build_strategy_.async_mode_) {
+          share_memory();
+        } else if (member_->use_all_reduce_ || member_->use_cuda_ ||
+                   var == "@LR_DECAY_COUNTER@") {
+          copy_memory();
         } else {
-          t->ShareDataWith(main_tensor);
+          share_memory();
         }
       }
     }