|
|
|
@ -21,10 +21,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph.h"
|
|
|
|
|
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#include "paddle/fluid/platform/nccl_helper.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
|
|
|
|
@ -39,6 +35,8 @@ limitations under the License. */
|
|
|
|
|
DEFINE_string(pe_profile_fname, "",
|
|
|
|
|
"Profiler filename for PE, which generated by gperftools."
|
|
|
|
|
"Only valid when compiled `WITH_PRIFILER=ON`. Empty if disable.");
|
|
|
|
|
DEFINE_bool(enable_parallel_graph, true,
|
|
|
|
|
"Force disable parallel graph execution mode if set false.");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -211,15 +209,6 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
"the number of places must be greater than 1.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// FIXME(Yancey1989): parallel graph mode get better performance
|
|
|
|
|
// in GPU allreduce distributed training. Need an elegant way to
|
|
|
|
|
// choice the execution strategy.
|
|
|
|
|
build_strategy.enable_parallel_graph_ =
|
|
|
|
|
EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
|
|
|
|
|
|
|
|
|
|
VLOG(1) << "Enable ParallelGraph Execution: "
|
|
|
|
|
<< build_strategy.enable_parallel_graph_;
|
|
|
|
|
|
|
|
|
|
// Step 1. Bcast the bcast_vars to devs.
|
|
|
|
|
// Create local scopes
|
|
|
|
|
if (local_scopes.empty()) {
|
|
|
|
@ -236,24 +225,35 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// FIXME(Yancey1989): parallel graph mode get better performance
|
|
|
|
|
// in GPU allreduce distributed training. Need an elegant way to
|
|
|
|
|
// choice the execution strategy.
|
|
|
|
|
build_strategy.enable_parallel_graph_ =
|
|
|
|
|
EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
|
|
|
|
|
|
|
|
|
|
VLOG(1) << "Enable ParallelGraph Execution: "
|
|
|
|
|
<< build_strategy.enable_parallel_graph_;
|
|
|
|
|
|
|
|
|
|
if (member_->use_cuda_) {
|
|
|
|
|
// Bcast Parameters to all GPUs
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
ncclUniqueId *nccl_id = nullptr;
|
|
|
|
|
// gen_nccl_id operator can broadcast the ncclUniqueId for nccl2 collective
|
|
|
|
|
// distributed training
|
|
|
|
|
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
|
|
|
|
|
std::unique_ptr<ncclUniqueId> nccl_id;
|
|
|
|
|
// nccl collective would broadcast ncclUniqueId by gen_nccl_id operator.
|
|
|
|
|
if (nccl_id_var != nullptr) {
|
|
|
|
|
nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
|
|
|
|
|
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
|
|
|
|
|
}
|
|
|
|
|
if (build_strategy.enable_parallel_graph_ && member_->nranks_ > 1UL) {
|
|
|
|
|
if (nccl_id.get() == nullptr) {
|
|
|
|
|
nccl_id.reset(new ncclUniqueId());
|
|
|
|
|
platform::dynload::ncclGetUniqueId(nccl_id.get());
|
|
|
|
|
if (nccl_id == nullptr) {
|
|
|
|
|
local_nccl_id_.reset(new ncclUniqueId());
|
|
|
|
|
platform::dynload::ncclGetUniqueId(local_nccl_id_.get());
|
|
|
|
|
nccl_id = local_nccl_id_.get();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
|
|
|
|
|
member_->places_, nccl_id.get(), num_trainers, trainer_id));
|
|
|
|
|
member_->places_, nccl_id, num_trainers, trainer_id));
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("Not compiled with CUDA");
|
|
|
|
|
#endif
|
|
|
|
@ -492,7 +492,7 @@ bool ParallelExecutor::EnableParallelGraphExecution(
|
|
|
|
|
if (build_strategy.enable_sequential_execution_ ||
|
|
|
|
|
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
|
|
|
|
|
enable_parallel_graph = false;
|
|
|
|
|
return enable_parallel_graph;
|
|
|
|
|
return enable_parallel_graph && FLAGS_enable_parallel_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ParallelExecutor::~ParallelExecutor() {
|
|
|
|
|