From ec6ee0a2939c53f3d520c9ec51492e39bd1c33ee Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 17 Sep 2018 17:39:29 +0800 Subject: [PATCH 1/2] simplify and hide bcast_params --- paddle/fluid/framework/parallel_executor.cc | 45 +++------------------ paddle/fluid/framework/parallel_executor.h | 2 +- 2 files changed, 7 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5b8c75a93d..48e440bda6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -209,30 +209,9 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToDevices( const std::unordered_set &vars) const { - // the initializing bcast, all vars would be bcast from device(0), - // otherwise - // bcast from the specified device. - bool initializing = member_->executor_ ? false : true; + // the initializing bcast, all vars would be bcast from device(0). for (auto &var : vars) { - int var_dev_id = -1; - if (member_->executor_) { - auto &sharded_var_device = - member_->executor_->Graph().Get( - details::kShardedVarDevice); - if (sharded_var_device.find(var) != sharded_var_device.end()) { - var_dev_id = sharded_var_device.at(var); - } - } - - if (!initializing && var_dev_id == -1) continue; - - framework::Variable *main_var = nullptr; - if (initializing) { - main_var = member_->local_scopes_[0]->FindVar(var); - } else { - main_var = member_->local_scopes_[var_dev_id]->FindVar(var); - } - + framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var); if (main_var == nullptr || !main_var->IsType()) { continue; } @@ -248,8 +227,7 @@ void ParallelExecutor::BCastParamsToDevices( auto place = member_->places_[i]; void *buffer; - if ((initializing && i == 0) || - (!initializing && static_cast(i) == var_dev_id)) { + if (i == 0) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; @@ -266,29 +244,18 @@ void ParallelExecutor::BCastParamsToDevices( platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]); - if (initializing) { - platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); - } else { - if (var_dev_id >= 0) { - platform::dynload::ncclBcast(buffers[i], numel, data_type, - var_dev_id, nccl_ctx.comm_, - nccl_ctx.stream()); - } - } + platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); } member_->nccl_ctxs_->WaitAll(); } - #else PADDLE_THROW("Not compiled with CUDA"); #endif } else { platform::CPUPlace cpu; for (size_t i = 0; i < member_->places_.size(); ++i) { - if ((initializing && i == 0) || - (!initializing && static_cast(i) == var_dev_id)) - continue; + if (i == 0) continue; auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 5fb748fa20..557d8be227 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -66,9 +66,9 @@ class ParallelExecutor { void Run(const std::vector &fetch_tensors, const std::string &fetched_var_name); + private: void BCastParamsToDevices(const std::unordered_set &vars) const; - private: ParallelExecutorPrivate *member_; }; From e5b322051b13811747bc5244a093cdb22caceeb6 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 17 Sep 2018 18:53:04 +0800 Subject: [PATCH 2/2] clean --- .../details/multi_devices_graph_pass.cc | 3 +++ .../details/multi_devices_graph_pass.h | 18 ++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 250e093a5f..8f319116ab 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes"; static const char kStrategy[] = "strategy"; void MultiDevSSAGraphBuilder::Init() const { + all_vars_.clear(); + balance_vars_.clear(); + loss_var_name_ = Get(kLossVarName); places_ = Get>(kPlaces); local_scopes_ = Get>(kLocalScopes); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 1ca8c4b855..47aaa80f4d 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t device_id) const; void Init() const; - private: - mutable std::string loss_var_name_; - mutable std::vector places_; - mutable std::vector local_scopes_; - mutable std::unordered_set grad_names_; - #ifdef PADDLE_WITH_CUDA mutable platform::NCCLContextMap *nccl_ctxs_; #endif @@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t GetAppropriateDeviceID( const std::vector &var_names) const; - private: + void SetCommunicationContext(OpHandleBase *op_handle, + const platform::Place &p) const; + + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; mutable std::vector balance_vars_; - - void SetCommunicationContext(OpHandleBase *op_handle, - const platform::Place &p) const; }; } // namespace details } // namespace framework