helinwang-patch-1
Yu Yang 7 years ago
parent e3144393e3
commit a7b0d5bd26

parallel_executor.cc

@@ -27,15 +27,16 @@ namespace framework {
 class ParallelExecutorPrivate {
  public:
   explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
-      : places_(places), fetch_dev_ctxs_(places) {}
+      : places_(places) {}
   std::vector<platform::Place> places_;
-  platform::DeviceContextPool fetch_dev_ctxs_;
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;
+  std::unique_ptr<details::SSAGraphExecutor> executor_;
 #ifdef PADDLE_WITH_CUDA
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
-  std::unique_ptr<details::SSAGraphExecutor> executor_;
 #endif
 };

 ParallelExecutor::ParallelExecutor(
@@ -55,7 +56,9 @@ ParallelExecutor::ParallelExecutor(
   }
   // Bcast Parameters to all GPUs
-  BuildNCCLCommunicator();
+#ifdef PADDLE_WITH_CUDA
+  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+#endif
   if (platform::is_gpu_place(places[0]) &&
       member_->local_scopes_.size() != 1) {  // Is CUDA
     BCastParamsToGPUs(startup_program);
@@ -123,12 +126,6 @@ void ParallelExecutor::BCastParamsToGPUs(
 #endif
 }

-void ParallelExecutor::BuildNCCLCommunicator() const {
-#ifdef PADDLE_WITH_CUDA
-  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
-#endif
-}
-
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
   auto fetch_data = member_->executor_->Run(fetch_tensors);
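
Taken together, the hunks above leave ParallelExecutorPrivate looking roughly like the sketch below. This is only a reconstruction from the lines visible in this diff (it assumes executor_ ends up outside the CUDA guard; member order and blank lines are approximate), not the complete source file:

    // Reconstructed view of ParallelExecutorPrivate after this commit,
    // assembled from the diff hunks above (not the complete file).
    class ParallelExecutorPrivate {
     public:
      explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
          : places_(places) {}  // fetch_dev_ctxs_ is no longer initialized here

      std::vector<platform::Place> places_;
      std::vector<Scope *> local_scopes_;
      Scope *global_scope_;

      // The SSA graph executor is used by CPU and GPU builds alike,
      // so it sits outside the CUDA-only section.
      std::unique_ptr<details::SSAGraphExecutor> executor_;

    #ifdef PADDLE_WITH_CUDA
      // NCCL contexts exist only in CUDA builds; they are now created
      // directly in the ParallelExecutor constructor rather than in the
      // removed BuildNCCLCommunicator() helper.
      std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
    #endif
    };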

parallel_executor.h

@@ -31,6 +31,8 @@ namespace framework {
 class ParallelExecutorPrivate;

 class ParallelExecutor {
+  DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
+
  public:
   explicit ParallelExecutor(size_t num_threads,
                             const std::vector<platform::Place>& places,
@@ -46,8 +48,6 @@ class ParallelExecutor {
   ParallelExecutorPrivate* member_;

   void BCastParamsToGPUs(const ProgramDesc& startup_program) const;
-  void BuildNCCLCommunicator() const;
 };

 }  // namespace framework
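
The DISABLE_COPY_AND_ASSIGN(ParallelExecutor); line added to the header is the usual non-copyable idiom; presumably it keeps the raw ParallelExecutorPrivate* member_ from being aliased by accidental copies. The macro's definition is not part of this diff, so the version below is only a minimal sketch of how such a macro is commonly written, not Paddle's verbatim definition:

    // Hypothetical DISABLE_COPY_AND_ASSIGN-style macro: deleting the copy
    // constructor and copy assignment operator makes the class non-copyable.
    #define DISABLE_COPY_AND_ASSIGN(classname)          \
      classname(const classname &) = delete;            \
      classname &operator=(const classname &) = delete

With a definition along these lines, code such as `ParallelExecutor copy = *pe;` fails to compile, which is the behavior the header change relies on.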
