|
|
|
@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
const std::unordered_set<std::string> &bcast_vars,
|
|
|
|
const std::unordered_set<std::string> &bcast_vars,
|
|
|
|
const ProgramDesc &main_program, const std::string &loss_var_name,
|
|
|
|
const ProgramDesc &main_program, const std::string &loss_var_name,
|
|
|
|
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
|
|
|
|
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
|
|
|
|
bool customize_scale_loss)
|
|
|
|
bool use_default_grad_scale)
|
|
|
|
: member_(new ParallelExecutorPrivate(places)) {
|
|
|
|
: member_(new ParallelExecutorPrivate(places)) {
|
|
|
|
member_->global_scope_ = scope;
|
|
|
|
member_->global_scope_ = scope;
|
|
|
|
|
|
|
|
|
|
|
|
@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
details::MultiDevSSAGraphBuilder builder(
|
|
|
|
details::MultiDevSSAGraphBuilder builder(
|
|
|
|
member_->places_, loss_var_name, params, member_->local_scopes_,
|
|
|
|
member_->places_, loss_var_name, params, member_->local_scopes_,
|
|
|
|
customize_scale_loss, member_->nccl_ctxs_.get());
|
|
|
|
use_default_grad_scale, member_->nccl_ctxs_.get());
|
|
|
|
#else
|
|
|
|
#else
|
|
|
|
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
|
|
|
|
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
|
|
|
|
params, member_->local_scopes_,
|
|
|
|
params, member_->local_scopes_,
|
|
|
|
customize_scale_loss);
|
|
|
|
use_default_grad_scale);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
auto graph = builder.Build(main_program);
|
|
|
|
auto graph = builder.Build(main_program);
|
|
|
|
|
|
|
|
|
|
|
|
|