From 4c1f9983d73d50373c30105fc4d985254d4698fb Mon Sep 17 00:00:00 2001 From: ZPaC Date: Thu, 10 Sep 2020 21:50:42 +0800 Subject: [PATCH] Optimize PS code. --- .../cpu/ps/embedding_look_up_proxy_kernel.cc | 8 ++-- .../cpu/ps/embedding_look_up_ps_kernel.cc | 6 +-- .../cpu/ps/embedding_look_up_ps_kernel.h | 2 +- .../kernel_compiler/cpu/ps/pserver_kernel.cc | 8 +++- .../kernel_compiler/cpu/ps/pserver_kernel.h | 25 +---------- .../kernel_compiler/cpu/ps/pull_kernel.h | 4 +- .../kernel_compiler/cpu/ps/push_kernel.h | 2 +- .../cpu/ps/sparse_apply_adam_ps_kernel.cc | 7 +--- .../cpu/ps/sparse_apply_adam_ps_kernel.h | 2 +- .../cpu/ps/sparse_apply_ftrl_ps_kernel.cc | 7 +--- .../cpu/ps/sparse_apply_ftrl_ps_kernel.h | 2 +- .../ps/sparse_apply_lazy_adam_ps_kernel.cc | 8 +--- .../cpu/ps/sparse_apply_lazy_adam_ps_kernel.h | 2 +- .../ccsrc/backend/session/session_basic.cc | 8 ++-- .../frontend/parallel/ps/optimizer_info.cc | 41 ++++++++++--------- .../frontend/parallel/ps/optimizer_info.h | 14 +++---- .../parallel/ps/optimizer_info_builder.cc | 39 ++++++++++++++++-- .../parallel/ps/optimizer_info_builder.h | 4 +- .../frontend/parallel/ps/parameter_server.h | 41 +++++++------------ mindspore/ccsrc/frontend/parallel/ps/worker.h | 16 ++++---- .../ccsrc/frontend/parallel/ps/worker_proxy.h | 13 +++--- mindspore/ccsrc/pipeline/jit/action.cc | 2 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 4 +- .../scripts/run_parameter_server_train.sh | 2 +- .../scripts/run_parameter_server_train_gpu.sh | 2 +- model_zoo/official/cv/resnet/train.py | 5 ++- 26 files changed, 136 insertions(+), 138 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index cfc22b1a07..faf00853a5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -42,8 +42,8 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; std::vector lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())}; if (mindspore::parallel::ps::Util::IsRoleOfWorker()) { - parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape[axis]); - parallel::ps::Worker::GetInstance().InitPSEmbeddingTable(keys, values, lens); + parallel::ps::worker.AddEmbeddingTable(key_, input_shape[axis]); + parallel::ps::worker.InitPSEmbeddingTable(keys, values, lens); } } @@ -64,8 +64,8 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector &i if (ret != EOK) { MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; } - parallel::ps::Worker::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result, - parallel::ps::kEmbeddingLookupCmd); + parallel::ps::worker.DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result, + parallel::ps::kEmbeddingLookupCmd); auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size); if (ret2 != EOK) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc index 66e154f954..0ea4718fa9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -17,6 +17,7 @@ #include 
"backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" #include #include +#include #include "backend/kernel_compiler/common_utils.h" #include "frontend/parallel/ps/util.h" @@ -54,9 +55,8 @@ void EmbeddingLookUpPSKernel::InitKernel( output_size_list_.emplace_back(output_size); } -void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - const auto &indices_shape = *(shape_vec[0]); +void EmbeddingLookUpPSKernel::ReInit(const std::vector> &shapes) { + const auto &indices_shape = shapes[0]; indices_lens_ = indices_shape[0]; size_t output_size = sizeof(float) * indices_lens_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h index 815b2b6f77..251aaba5e3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h @@ -31,7 +31,7 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK ~EmbeddingLookUpPSKernel() override = default; void InitKernel(const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; + void ReInit(const std::vector> &) override; bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc index e6a62a1daa..a7d1189835 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc @@ -14,8 +14,14 @@ * limitations under the License. 
*/ +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" + namespace mindspore { namespace kernel { -namespace ps {} // namespace ps +namespace ps { +void PServerKernel::Shard(std::vector *shape, int axis) { + (*shape)[axis] = Util::LocalShard((*shape)[axis], rank_id_, pserver_num_); +} +} // namespace ps } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h index ba0f6c8f8f..fe978f454e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h @@ -35,7 +35,7 @@ class PServerKernel { virtual void InitKernel(const std::shared_ptr>>> &) {} virtual void InitKernel(const CNodePtr &cnode, const std::shared_ptr>>> &) {} - virtual void ReInit(const std::shared_ptr>>> &) {} + virtual void ReInit(const std::vector> &) {} virtual bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) = 0; @@ -45,32 +45,11 @@ class PServerKernel { protected: virtual void ReInit(const std::vector &) {} - - void SetTotalRowCnt(size_t total_cnt) { - MS_LOG(INFO) << "Total row count of server " << rank_id_ << " is " << total_cnt; - total_row_cnt_ = total_cnt; - } - - void CalOffset() { - size_t rem = total_row_cnt_ % pserver_num_; - if (rem == 0) { - row_offset_ = total_row_cnt_ / pserver_num_ * rank_id_; - } else { - row_offset_ = std::round((static_cast(total_row_cnt_)) / pserver_num_) * rank_id_; - } - MS_LOG(INFO) << "Row offset of server " << rank_id_ << " is " << row_offset_; - } - - void Shard(std::vector *shape, int axis) { - (*shape)[axis] = Util::LocalShard((*shape)[axis], rank_id_, pserver_num_); - } + void Shard(std::vector *shape, int axis); size_t rank_id_; size_t pserver_num_; size_t worker_num_; - - size_t total_row_cnt_; - size_t row_offset_; }; } // namespace ps } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h index 350b503d8b..3b1794306d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h @@ -29,11 +29,11 @@ namespace kernel { template class PullKernel : public CPUKernel { public: - PullKernel() : keys_size_(sizeof(size_t)), var_size_(sizeof(size_t)) {} + PullKernel() : key_(UINT64_MAX), keys_size_(sizeof(size_t)), var_size_(sizeof(size_t)) {} ~PullKernel() override = default; bool Launch(const std::vector &inputs, const std::vector &, const std::vector &) { - bool init_in_server = mindspore::parallel::ps::Worker::GetInstance().GetParamInitInServer(param_name_); + bool init_in_server = parallel::ps::worker.GetParamInitInServer(param_name_); // If init_in_server, forward kernel should run in server too. 
if (!init_in_server) { parallel::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h index 800315e5f3..1a31da20a4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h @@ -58,7 +58,7 @@ class PushKernel : public CPUKernel { MS_LOG(INFO) << "Only init shape indices are " << only_shape_indices; for (size_t i = 0; i < optim_input_shapes.size(); i++) { auto shape = optim_input_shapes[i]; - mindspore::parallel::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape); + parallel::ps::worker.SetOptimInputShapes(key_, shape); if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) { size_t size = sizeof(T); for (size_t j = 0; j < shape.size(); j++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc index 25af4fd4f0..5166bc1398 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc @@ -31,8 +31,6 @@ void SparseApplyAdamPSKernel::InitKernel( const std::vector &grad_shape = *(shape_vec[9]); const std::vector &indices_shape = *(shape_vec[10]); - SetTotalRowCnt(var_shape[0]); - CalOffset(); Shard(&var_shape, 0); Shard(&m_shape, 0); Shard(&v_shape, 0); @@ -67,9 +65,8 @@ void SparseApplyAdamPSKernel::InitKernel( workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); } -void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - const std::vector &indices_shape = *(shape_vec[0]); +void SparseApplyAdamPSKernel::ReInit(const std::vector> &shapes) { + const std::vector &indices_shape = shapes[0]; indices_size_ = indices_shape[0]; workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_; workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h index a28e62abd6..114f518f19 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h @@ -33,7 +33,7 @@ class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerK void InitKernel(const CNodePtr &cnode, const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; + void ReInit(const std::vector> &) override; bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc index 9f5f2abfd2..8949745065 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -28,8 +28,6 @@ void SparseApplyFtrlPSKernel::InitKernel( std::vector grad_shape = *(shape_vec[3]); std::vector indices_shape = *(shape_vec[4]); - SetTotalRowCnt(var_shape[0]); - CalOffset(); 
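The sparse optimizer kernels (Adam above, FTRL and LazyAdam below) all follow the same ReInit pattern with the new shapes argument: read the indices length from shapes[0] and recompute the two workspace sizes that depend on it. A compilable sketch with a hypothetical stand-in class and made-up member values:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the sparse optimizer PS kernels; only the ReInit
// body mirrors the hunks, the member values are made up for the example.
struct SparseOptimReInitSketch {
  size_t indices_size_ = 0;
  size_t var_outer_dim_size_ = 4;  // product of the non-row dimensions of the variable
  size_t worker_num_ = 2;
  std::vector<size_t> workspace_size_list_ = {0, 0};

  void ReInit(const std::vector<std::vector<size_t>> &shapes) {
    const std::vector<size_t> &indices_shape = shapes[0];
    indices_size_ = indices_shape[0];
    // Only the two indices-dependent workspaces need resizing between steps.
    workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
    workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
  }
};

int main() {
  SparseOptimReInitSketch kernel;
  kernel.ReInit({{256}});  // a step arrives with 256 lookup indices
  return kernel.workspace_size_list_[0] == 256 * 4 * sizeof(float) * 2 ? 0 : 1;
}
```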
Shard(&var_shape, 0); Shard(&accum_shape, 0); Shard(&linear_shape, 0); @@ -74,9 +72,8 @@ void SparseApplyFtrlPSKernel::InitKernel( workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); } -void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - std::vector indices_shape = *(shape_vec[0]); +void SparseApplyFtrlPSKernel::ReInit(const std::vector> &shapes) { + const std::vector &indices_shape = shapes[0]; indices_size_ = indices_shape[0]; workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_; workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h index 64a3c259af..4ec5d363fd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h @@ -33,7 +33,7 @@ class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerK void InitKernel(const CNodePtr &cnode, const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; + void ReInit(const std::vector> &) override; bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc index 82fbfc3a2a..eabbefb42f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc @@ -31,8 +31,6 @@ void SparseApplyLazyAdamPSKernel::InitKernel( const std::vector &grad_shape = *(shape_vec[9]); const std::vector &indices_shape = *(shape_vec[10]); - SetTotalRowCnt(var_shape[0]); - CalOffset(); Shard(&var_shape, 0); Shard(&m_shape, 0); Shard(&v_shape, 0); @@ -66,10 +64,8 @@ void SparseApplyLazyAdamPSKernel::InitKernel( workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); } -void SparseApplyLazyAdamPSKernel::ReInit( - const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - const std::vector &indices_shape = *(shape_vec[0]); +void SparseApplyLazyAdamPSKernel::ReInit(const std::vector> &shapes) { + const std::vector &indices_shape = shapes[0]; indices_size_ = indices_shape[0]; workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_; workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h index 070f42b96c..3d232887bd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h @@ -33,7 +33,7 @@ class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public void InitKernel(const CNodePtr &cnode, const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; + void ReInit(const std::vector> &) override; bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; diff --git 
a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 1de96c166c..9af3d9eedc 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1393,7 +1393,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { size_t embedding_table_idx = 0; auto embedding_table = AnfAlgo::GetInputNode(node->cast(), embedding_table_idx); - size_t key = parallel::ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope()); + size_t key = parallel::ps::worker.SetParamKey(embedding_table->fullname_with_scope()); AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) { auto pull_node = FindPullNode(node, node_list); @@ -1404,12 +1404,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { // Second input of Pull node is the trainable parameter. size_t parameter_index = 1; auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast(), parameter_index); - size_t key = parallel::ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope()); + size_t key = parallel::ps::worker.SetParamKey(parameter_node->fullname_with_scope()); AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node); std::string optimizer_name = AnfAlgo::GetNodeAttr(node, kAttrOptimizerType); - parallel::ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name); + parallel::ps::worker.SetKeyOptimId(key, optimizer_name); } } } @@ -1440,7 +1440,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { auto pk_node = input_node->cast(); - mindspore::parallel::ps::Worker::GetInstance().InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor); + parallel::ps::worker.InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor); } } } diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index 94afc07fad..fd23944714 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -15,6 +15,7 @@ */ #include "frontend/parallel/ps/optimizer_info.h" +#include #include #include "frontend/parallel/ps/util.h" @@ -37,13 +38,6 @@ size_t OptimizerInfo::grad_index() { return 0; } size_t OptimizerInfo::indices_index() { return 0; } -void OptimizerInfo::UpdateWeight(const WeightPtr &weight) { - AddressPtr weight_addr = std::make_shared(); - weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); - inputs_[0] = weight_addr; -} - void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { float *accum_grad_data = reinterpret_cast(gradient()->addr); size_t size = gradient()->size / sizeof(float); @@ -60,8 +54,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { } } -void DenseOptimInfo::ComputeMean(const std::shared_ptr>>> &, size_t n, - size_t server_num, size_t rank_id) { +void DenseOptimInfo::ComputeMean(const std::vector> &, size_t n, size_t, size_t) { if (n > 1) { float *accum_grad_data = reinterpret_cast(gradient()->addr); size_t size = gradient()->size / sizeof(float); @@ -88,6 +81,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { auto ret 
= memcpy_s(accum_grad_data + grads_offset_, incr_grad_size, incr_grad_data, incr_grad_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } grads_offset_ += lengths[grad_index]; gradient()->size += incr_grad_size; @@ -107,13 +101,14 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { memcpy_s(accum_indices_data + indices_offset_, incr_indice_data_size, incr_indice_data, incr_indice_data_size); if (ret2 != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + return; } indices_offset_ += lengths[indices_index]; indices()->size += incr_indice_data_size; } -void SparseOptimInfo::ComputeMean(const std::shared_ptr>>> &shapes, - size_t n, size_t server_num, size_t rank_id) { +void SparseOptimInfo::ComputeMean(const std::vector> &shapes, size_t n, size_t server_num, + size_t rank_id) { size_t indices_size = static_cast(indices()->size / sizeof(int)); int segment_size = gradient()->size / indices()->size; @@ -121,16 +116,15 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr new_indices(indices_size); mindspore::kernel::SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size}); - const std::vector>> &shape_vec = *shapes; - if (shape_vec.size() < 2 || shape_vec[1] == nullptr) { + // const std::vector>> &shape_vec = *shapes; + if (shapes.size() < 2 || shapes[1].empty()) { MS_LOG(EXCEPTION) << "No input shape found"; } - auto input_shapes = shape_vec.size() > 0 ? shape_vec[1] : nullptr; - MS_EXCEPTION_IF_NULL(input_shapes); - if (input_shapes->size() == 0) { + auto input_shapes = shapes[1]; + if (input_shapes.size() == 0) { MS_LOG(EXCEPTION) << "Invalid input shapes"; } - int first_dim_size = input_shapes->front(); + int first_dim_size = input_shapes.front(); int outer_dim_size = segment_size; if (first_dim_size == 0 || outer_dim_size == 0) { @@ -140,7 +134,7 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr(gradient()->addr); int *indices_data = reinterpret_cast(indices()->addr); - size_t original_row_count = input_shapes->front(); + size_t original_row_count = input_shapes.front(); if (original_row_count > 0) { size_t offset = 0; std::map rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num); @@ -162,11 +156,13 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptraddr, reduced_grad_size, unique_sparse_grad.value_, reduced_grad_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } int reduced_indice_size = unique_sparse_grad.indices_size_ * sizeof(int); ret = memcpy_s(indices()->addr, reduced_indice_size, unique_sparse_grad.indices_, reduced_indice_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } gradient()->size = reduced_grad_size; @@ -197,11 +193,12 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr } void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { - size_t lr_offset = 0; + const size_t lr_offset = 0; float *lr = values.data() + lr_offset; auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } } @@ -243,6 +240,7 @@ void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { auto ret = memcpy_s(beta1_power->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += 
size; @@ -251,6 +249,7 @@ void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { ret = memcpy_s(beta2_power->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += size; @@ -259,6 +258,7 @@ void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { ret = memcpy_s(lr->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += size; @@ -267,6 +267,7 @@ void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { ret = memcpy_s(beta1->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += size; @@ -275,6 +276,7 @@ void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { ret = memcpy_s(beta2->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += size; @@ -283,6 +285,7 @@ void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { ret = memcpy_s(epsilon->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } } diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h index 0edb2d2156..5ed96283c0 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -18,7 +18,6 @@ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_OPTIMIZER_INFO_H_ #include -#include #include "backend/kernel_compiler/kernel.h" #include "frontend/parallel/ps/common.h" @@ -32,10 +31,9 @@ class OptimizerInfo { virtual ~OptimizerInfo() = default; virtual void Update(const Values &values, const Lengths &lengths) {} - virtual void UpdateWeight(const WeightPtr &weight); virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; - virtual void ComputeMean(const std::shared_ptr>>> &shapes, size_t n, - size_t server_num, size_t rank_id) {} + virtual void ComputeMean(const std::vector> &shapes, size_t n, size_t server_num, + size_t rank_id) {} virtual void Reset() {} void AddWorkspace(const AddressPtr &workspace); @@ -62,8 +60,8 @@ class DenseOptimInfo : public OptimizerInfo { ~DenseOptimInfo() override = default; void Accumulate(const Values &values, const Lengths &lens) override; - void ComputeMean(const std::shared_ptr>>> &shapes, size_t n, - size_t server_num, size_t rank_id) override; + void ComputeMean(const std::vector> &shapes, size_t n, size_t server_num, + size_t rank_id) override; void Reset() override; }; @@ -73,8 +71,8 @@ class SparseOptimInfo : public OptimizerInfo { ~SparseOptimInfo() override = default; void Accumulate(const Values &values, const Lengths &lens) override; - void ComputeMean(const std::shared_ptr>>> &shapes, size_t n, - size_t server_num, size_t rank_id) override; + void ComputeMean(const std::vector> &shapes, size_t n, size_t server_num, + size_t rank_id) override; void Reset() override; const size_t indice_size() const override; diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc index f0a0491968..9bff9bfbb6 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc +++ 
b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ -51,11 +51,13 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co AddressPtr weight_addr = std::make_shared(); weight_addr->addr = weight->data(); weight_addr->size = weight->size() * sizeof(float); - void *data_ptr = values.data(); - void *copy_data_ptr = new float[values.size()]; + float *data_ptr = values.data(); + float *copy_data_ptr = new float[values.size()]; auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] copy_data_ptr; + return nullptr; } AddressPtr accumulate = std::make_shared(); accumulate->addr = new float[weight->size()]; @@ -86,6 +88,8 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, int ret = memset_s(m->addr, m->size, 0x00, m->size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] reinterpret_cast(m->addr); + return nullptr; } AddressPtr v = std::make_shared(); v->addr = new float[weight->size()]; @@ -93,13 +97,20 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, ret = memset_s(v->addr, v->size, 0x00, v->size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] reinterpret_cast(v->addr); + delete[] reinterpret_cast(m->addr); + return nullptr; } - void *data_ptr = values.data(); - void *copy_data_ptr = new float[values.size()]; + float *data_ptr = values.data(); + float *copy_data_ptr = new float[values.size()]; ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] copy_data_ptr; + delete[] reinterpret_cast(v->addr); + delete[] reinterpret_cast(m->addr); + return nullptr; } AddressPtr beta1_power = std::make_shared(); @@ -134,6 +145,11 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, lens[6] * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] reinterpret_cast(grad->addr); + delete[] copy_data_ptr; + delete[] reinterpret_cast(v->addr); + delete[] reinterpret_cast(m->addr); + return nullptr; } grad->size = lens[6] * sizeof(float); @@ -147,6 +163,12 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, ret = memcpy_s(indices->addr, indices_data_size, indices_data, indices_data_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] reinterpret_cast(indices->addr); + delete[] reinterpret_cast(grad->addr); + delete[] copy_data_ptr; + delete[] reinterpret_cast(v->addr); + delete[] reinterpret_cast(m->addr); + return nullptr; } indices->size = indices_data_size; @@ -173,6 +195,8 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")"; + delete[] reinterpret_cast(linear->addr); + return nullptr; } linear->size = weight->size() * sizeof(float); @@ -183,6 +207,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, 
errorno(" << ret << ")"; + delete[] reinterpret_cast(grad->addr); + delete[] reinterpret_cast(linear->addr); + return nullptr; } grad->size = lens[0] * sizeof(float); @@ -196,6 +223,10 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, ret = memcpy_s(indices->addr, indices_data_size, indices_data, indices_data_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + delete[] reinterpret_cast(indices->addr); + delete[] reinterpret_cast(grad->addr); + delete[] reinterpret_cast(linear->addr); + return nullptr; } indices->size = indices_data_size; diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h index 21920938a2..52928fb938 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h @@ -55,14 +55,14 @@ class MomentumOptimInfoBuilder : public OptimizerInfoBuilder { class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder { public: OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, - const InputsShapePtr &inputs_shpae, size_t worker_num, + const InputsShapePtr &inputs_shape, size_t worker_num, const std::shared_ptr &pserver_kernel) override; }; class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder { public: OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, - const InputsShapePtr &inputs_shpae, size_t worker_num, + const InputsShapePtr &inputs_shape, size_t worker_num, const std::shared_ptr &pserver_kernel) override; }; } // namespace ps diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 7290451a14..3874254e80 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -31,6 +31,7 @@ #include #include #include +#include #include "ir/func_graph.h" #include "backend/session/session_basic.h" #include "backend/session/anf_runtime_algorithm.h" @@ -85,6 +86,7 @@ class ParameterServer { class ServerHandler { public: explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} + ~ServerHandler() = default; void Init(); void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); @@ -124,7 +126,6 @@ class ParameterServer { void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); WeightPtr weight(const Key &key); void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); - int SumOfShapes(const std::vector &shapes) const; bool ReadyForUpdateWeights(); bool ReadyForPush(const Key &key); bool ReadyForPull(const Key &key); @@ -460,11 +461,7 @@ void ParameterServer::InitEmbeddingTable( // Init embedding weight const std::vector &input_shapes = lookup->input_sizes(); - size_t total_dims = 1; - for (auto shape : input_shapes) { - total_dims *= shape; - } - + size_t total_dims = std::accumulate(input_shapes.begin(), input_shapes.end(), 1, std::multiplies()); WeightPtr embedding = std::make_shared(total_dims, 0); T *embedding_data = embedding->data(); std::default_random_engine engine; @@ -517,15 +514,14 @@ void ParameterServer::UpdateWeights() { const std::vector &workspaces = optim_info->workspaces(); const std::vector &outputs = optim_info->outputs(); - std::shared_ptr>>> shapes = - std::make_shared>>>(); - 
std::shared_ptr> indices_shape = std::make_shared>(); - indices_shape->emplace_back(optim_info->indice_size()); - shapes->push_back(indices_shape); + std::vector> shapes = {}; + std::vector indices_shape = {}; + indices_shape.emplace_back(optim_info->indice_size()); + shapes.push_back(indices_shape); if (original_optim_inputs_shape_.count(key) != 0) { - for (auto &input_shapes : *(original_optim_inputs_shape_[key])) { - shapes->push_back(input_shapes); + for (auto input_shapes : *(original_optim_inputs_shape_[key])) { + shapes.push_back(*input_shapes); } } optimizer->ReInit(shapes); @@ -604,11 +600,10 @@ void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; // Update shapes of lookup operator - std::shared_ptr>>> shapes = - std::make_shared>>>(); - std::shared_ptr> indices_shape = std::make_shared>(); - indices_shape->emplace_back(lookup_ids.size()); - shapes->push_back(indices_shape); + std::vector> shapes = {}; + std::vector indices_shape = {}; + indices_shape.emplace_back(lookup_ids.size()); + shapes.push_back(indices_shape); table_lookup_op->ReInit(shapes); const std::vector output_shapes = table_lookup_op->output_sizes(); @@ -641,15 +636,6 @@ void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, res->lens.push_back(res->vals.size()); } -template -int ParameterServer::SumOfShapes(const std::vector &shapes) const { - int sum = 1; - for (auto shape : shapes) { - sum *= shape; - } - return sum; -} - template inline bool ParameterServer::ReadyForUpdateWeights() { return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); @@ -726,6 +712,7 @@ void ParameterServer::SyncEmbeddingTables() { int ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } auto paramter_tensor_ptr = embedding_table.second->default_param(); diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index 29b927511c..d4967f53c7 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include "ps/ps.h" #include "utils/log_adapter.h" @@ -124,10 +126,7 @@ void Worker::Push(const std::vector &keys, std::vector add indice_index = 1; } - size_t total_size = 0; - for (auto size : sizes) { - total_size += size; - } + size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus()); ::ps::SArray total_buffer(total_size, 0); size_t offset = 0; for (size_t i = 0; i < sizes.size(); i++) { @@ -135,6 +134,7 @@ void Worker::Push(const std::vector &keys, std::vector add reinterpret_cast(addrs[i]), sizes[i] * sizeof(T)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += sizes[i] * sizeof(T); } @@ -147,10 +147,7 @@ void Worker::Push(const std::vector &keys, std::vector add } else { std::vector &var_shape = key_to_optim_shapes_[key][0]; int first_dim_size = var_shape[0]; - int outer_dim_size = 1; - for (size_t i = 1; i < var_shape.size(); ++i) { - outer_dim_size *= var_shape[i]; - } + int outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies()); kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes), grad_index, indice_index, first_dim_size, 
outer_dim_size); } @@ -166,6 +163,7 @@ void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { auto ret = memcpy_s(dev_addr, size, variables.data(), size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } } @@ -349,6 +347,8 @@ void Worker::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) } kv_worker_->AddEmbeddingTable(key, row_count); } + +static Worker &worker = Worker::GetInstance(); } // namespace ps } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h index 3ed245bfce..e32e90ef3a 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -18,6 +18,8 @@ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_ #include +#include +#include #include #include #include @@ -247,7 +249,7 @@ void WorkerProxy::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const : kvs.keys = keys; kvs.vals = vals; kvs.lens = lens; - int cmd = 0; + const int cmd = 0; if (embedding_table_ranges_.count(keys[0])) { std::map attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}}; Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs); @@ -319,6 +321,7 @@ int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps: auto ret = memcpy_s(result_addr + offset, size, pair->first, size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += pair->second; } @@ -493,10 +496,7 @@ void WorkerProxy::SparseSlicer(int timestamp, const ::ps::KVPairs &send, c reduced_lens[indice_index] = unique_sparse_grad.indices_size_; // Build the sparse value to be sent - size_t total_size = 0; - for (auto size : reduced_lens) { - total_size += size; - } + size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus()); ::ps::SArray reduced_data(total_size, 0); BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_, unique_sparse_grad.indices_, &reduced_data); @@ -536,6 +536,7 @@ void WorkerProxy::PrepareSparseGradient(const size_t begin, const size_t end, auto ret = memcpy_s(gradient + offset, segment_data_size, pair.second, segment_data_size); if (ret != 0) { MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return; } offset += segment_size; } @@ -566,6 +567,7 @@ void WorkerProxy::BuildSparseValue(const ::ps::SArray &lengths, const si auto ret = memcpy_s(reduced_data->data() + grad_offset, data_size, grads, data_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } // Fill the reduced indice @@ -575,6 +577,7 @@ void WorkerProxy::BuildSparseValue(const ::ps::SArray &lengths, const si ret = memcpy_s(indice_data, data_size, indices, data_size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; } } diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 928c317848..202d079326 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -404,7 +404,7 @@ bool ExecuteAction(const ResourcePtr &res) { #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) bool StartPSWorkerAction(const ResourcePtr &res) { - parallel::ps::Worker::GetInstance().Run(); + parallel::ps::worker.Run(); return true; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc 
b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 5dc6e94fe7..9f1badfbf9 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -996,9 +996,9 @@ void ClearResAtexit() { pynative::ClearPyNativeSession(); session::ClearPythonParasMap(); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (mindspore::parallel::ps::Util::IsParamServerMode()) { + if (parallel::ps::Util::IsParamServerMode()) { if (parallel::ps::Util::IsRoleOfWorker()) { - parallel::ps::Worker::GetInstance().Finalize(); + parallel::ps::worker.Finalize(); } } #endif diff --git a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh index a041aef04e..d5b4b7e673 100644 --- a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh +++ b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh @@ -81,7 +81,7 @@ export RANK_TABLE_FILE=$PATH1 export MS_COMM_TYPE=zmq export MS_SCHED_NUM=1 export MS_WORKER_NUM=$RANK_SIZE -export MS_SERVER_NUM=1 +export MS_SERVER_NUM=8 export MS_SCHED_HOST=127.0.0.1 export MS_SCHED_PORT=8081 diff --git a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh index ecbb345eed..68276a842b 100755 --- a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh +++ b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh @@ -73,7 +73,7 @@ export RANK_SIZE=8 export MS_COMM_TYPE=zmq export MS_SCHED_NUM=1 export MS_WORKER_NUM=8 -export MS_SERVER_NUM=1 +export MS_SERVER_NUM=8 export MS_SCHED_HOST=127.0.0.1 export MS_SCHED_PORT=8081 diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 0cc619b9ba..6c76f37128 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -70,7 +70,8 @@ if __name__ == '__main__': # init context context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) - context.set_ps_context(enable_ps=True) + if args_opt.parameter_server: + context.set_ps_context(enable_ps=True) if args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) @@ -161,7 +162,7 @@ if __name__ == '__main__': else: loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - if args_opt.net == "resnet101" or args_opt.net == "resnet50": + if (args_opt.net == "resnet101" or args_opt.net == "resnet50") and not args_opt.parameter_server: opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, config.loss_scale) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
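A recurring change throughout this patch is the extra guard after memcpy_s/memset_s calls: log the error, free whatever was already allocated with new[], and return instead of running on with half-initialized buffers. The sketch below shows that shape in isolation; it substitutes a plain bounds check plus std::memcpy for the securec memcpy_s and uses std::unique_ptr so the cleanup on the early return is implicit rather than a chain of delete[] calls, so it illustrates the pattern rather than the builders' actual code.

```cpp
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

// Bounds-checked copy standing in for the securec memcpy_s used by MindSpore.
bool GuardedCopy(float *dst, size_t dst_len, const float *src, size_t src_len) {
  if (dst == nullptr || src == nullptr || src_len > dst_len) {
    std::cerr << "guarded copy rejected: bad pointer or destination too small\n";
    return false;
  }
  std::memcpy(dst, src, src_len * sizeof(float));
  return true;
}

int main() {
  std::vector<float> values = {1.0f, 2.0f, 3.0f};

  // The builders allocate raw float[] buffers; unique_ptr plays that role here
  // and makes the cleanup on an early return automatic instead of manual delete[].
  auto grad = std::make_unique<float[]>(values.size());
  auto copy = std::make_unique<float[]>(values.size());

  if (!GuardedCopy(grad.get(), values.size(), values.data(), values.size()) ||
      !GuardedCopy(copy.get(), values.size(), values.data(), values.size())) {
    return 1;  // both buffers are released here, nothing is left half-built
  }
  std::cout << "copied " << values.size() << " floats into each buffer\n";
  return 0;
}
```

The early `return` the patch adds after each MS_LOG(EXCEPTION) serves the same purpose in the real code: even if the logging macro's behaviour changes, no later statement touches a buffer whose copy failed.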