!4946 Get server rank id in python and fix multi server error.

Merge pull request !4946 from ZPaC/master-get-server-rank-in-python
pull/4946/MERGE
mindspore-ci-bot committed 5 years ago (via Gitee)
commit fd9be2ddc2

@@ -45,6 +45,22 @@ class PServerKernel {
  protected:
   virtual void ReInit(const std::vector<AddressPtr> &) {}
+  void SetTotalRowCnt(size_t total_cnt) {
+    MS_LOG(INFO) << "Total row count of server " << rank_id_ << " is " << total_cnt;
+    total_row_cnt_ = total_cnt;
+  }
+  void CalOffset() {
+    size_t rem = total_row_cnt_ % pserver_num_;
+    if (rem == 0) {
+      row_offset_ = total_row_cnt_ / pserver_num_ * rank_id_;
+    } else {
+      row_offset_ = std::round((static_cast<float>(total_row_cnt_)) / pserver_num_) * rank_id_;
+    }
+    MS_LOG(INFO) << "Row offset of server " << rank_id_ << " is " << row_offset_;
+  }
   void Shard(std::vector<size_t> *shape, int axis) {
     (*shape)[axis] = Util::LocalShard((*shape)[axis], rank_id_, pserver_num_);
   }
@@ -52,6 +68,9 @@ class PServerKernel {
   size_t rank_id_;
   size_t pserver_num_;
   size_t worker_num_;
+  size_t total_row_cnt_;
+  size_t row_offset_;
 };
 }  // namespace ps
 }  // namespace kernel
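A quick sanity check of the `CalOffset` arithmetic above, as a minimal Python sketch (names mirror the C++ members; `int(x + 0.5)` stands in for `std::round` on positive values — an illustration, not MindSpore API):

```python
def cal_offset(total_row_cnt, pserver_num, rank_id):
    """Mirror of PServerKernel::CalOffset for non-negative inputs."""
    if total_row_cnt % pserver_num == 0:
        return total_row_cnt // pserver_num * rank_id
    # std::round(x) for positive x is floor(x + 0.5).
    return int(total_row_cnt / pserver_num + 0.5) * rank_id

assert cal_offset(10, 2, 1) == 5  # even split: server 1 starts at row 5
assert cal_offset(10, 3, 2) == 6  # uneven: round(10 / 3) = 3, so offset 3 * 2
```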

@@ -31,6 +31,8 @@ void SparseApplyAdamPSKernel::InitKernel(
   const std::vector<size_t> &grad_shape = *(shape_vec[9]);
   const std::vector<size_t> &indices_shape = *(shape_vec[10]);
+  SetTotalRowCnt(var_shape[0]);
+  CalOffset();
   Shard(&var_shape, 0);
   Shard(&m_shape, 0);
   Shard(&v_shape, 0);
@@ -85,7 +87,7 @@ bool SparseApplyAdamPSKernel::Execute(const std::vector<AddressPtr> &inputs, con
   ReInit(inputs);
   int *indices = reinterpret_cast<int *>(inputs[10]->addr);
   for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) {
-    indices[i] -= rank_id_ * var_first_dim_size_;
+    indices[i] -= row_offset_;
   }
   return Launch(inputs, workspace, outputs);
 }
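The switch from `rank_id_ * var_first_dim_size_` to `row_offset_` is the multi-server fix: `var_first_dim_size_` is the local, already-sharded row count, which differs on the last server when the rows do not divide evenly, so multiplying it by the rank misplaces the indices. The same substitution appears in the Ftrl and LazyAdam kernels below. A hedged sketch of the failure case (assuming `var_first_dim_size_` holds the local shard size, as `Util::LocalShard` suggests):

```python
# 10 rows over 3 servers shard as [3, 3, 4] (see Util::LocalShard below),
# so on server 2 the local first dimension is 4.
total, servers, rank = 10, 3, 2
shard = int(total / servers + 0.5)               # 3 rows per non-last server
local_first_dim = total - shard * (servers - 1)  # 4 rows on the last server
old_offset = rank * local_first_dim              # 8: off by two
new_offset = shard * rank                        # 6: server 2 owns rows 6..9
assert (old_offset, new_offset) == (8, 6)
```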

@@ -28,6 +28,8 @@ void SparseApplyFtrlPSKernel::InitKernel(
   std::vector<size_t> grad_shape = *(shape_vec[3]);
   std::vector<size_t> indices_shape = *(shape_vec[4]);
+  SetTotalRowCnt(var_shape[0]);
+  CalOffset();
   Shard(&var_shape, 0);
   Shard(&accum_shape, 0);
   Shard(&linear_shape, 0);
@@ -88,7 +90,7 @@ bool SparseApplyFtrlPSKernel::Execute(const std::vector<AddressPtr> &inputs, con
   ReInit(inputs);
   int *indices = reinterpret_cast<int *>(inputs[4]->addr);
   for (size_t i = 0; i < inputs[4]->size / sizeof(int); i++) {
-    indices[i] -= rank_id_ * var_first_dim_size_;
+    indices[i] -= row_offset_;
   }
   return Launch(inputs, workspace, outputs);
 }

@@ -31,6 +31,8 @@ void SparseApplyLazyAdamPSKernel::InitKernel(
   const std::vector<size_t> &grad_shape = *(shape_vec[9]);
   const std::vector<size_t> &indices_shape = *(shape_vec[10]);
+  SetTotalRowCnt(var_shape[0]);
+  CalOffset();
   Shard(&var_shape, 0);
   Shard(&m_shape, 0);
   Shard(&v_shape, 0);
@@ -86,7 +88,7 @@ bool SparseApplyLazyAdamPSKernel::Execute(const std::vector<AddressPtr> &inputs,
   ReInit(inputs);
   int *indices = reinterpret_cast<int *>(inputs[10]->addr);
   for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) {
-    indices[i] -= rank_id_ * var_first_dim_size_;
+    indices[i] -= row_offset_;
   }
   return Launch(inputs, workspace, outputs);
 }

@@ -721,6 +721,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
     return;
   }
   Init(func_graph);
+  Util::SetRankId(rank_id_);
   thread_->join();
   ::ps::Finalize(0, true);
 }

@@ -22,6 +22,8 @@
 namespace mindspore {
 namespace parallel {
 namespace ps {
+int Util::rank_id_ = -1;
 std::unordered_map<std::string, int> Util::optimizer_to_ids{
   {kApplyMomentum, 0},
   {kSparseAdam, 1},
@@ -140,6 +142,10 @@ int Util::LocalShard(int first_dim, int rank_id, int server_num) {
     return first_dim - (shard_size * (server_num - 1));
   }
 }
+void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
+int Util::GetRankId() { return rank_id_; }
 }  // namespace ps
 }  // namespace parallel
 }  // namespace mindspore

@@ -37,11 +37,14 @@ class Util {
   static std::string optimizer_node_name(int id);
   static bool is_optimizer(std::string name);
   static int LocalShard(int first_dim, int rank_id, int server_num);
+  static void SetRankId(int rank_id);
+  static int GetRankId();
  private:
   static std::unordered_map<std::string, int> optimizer_to_ids;
   static std::unordered_map<int, std::string> id_to_optimizers;
   static std::unordered_map<int, std::string> id_to_optimizer_nodes;
+  static int rank_id_;
 };
 }  // namespace ps
 }  // namespace parallel

@@ -41,6 +41,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
   explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
       : Worker(app_id, customer_id) {
     server_num_ = ::ps::NumServers();
+    Util::SetRankId(::ps::MyRank());
     using std::placeholders::_1;
     using std::placeholders::_2;
     using std::placeholders::_3;

@@ -33,6 +33,9 @@
 #else
 #include "runtime/device/gpu/distribution/collective_fake_init.h"
 #endif
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#include "frontend/parallel/ps/util.h"
+#endif
 namespace py = pybind11;
 using EnvInstance = mindspore::EnvInstance;
@@ -322,7 +325,10 @@ PYBIND11_MODULE(_c_expression, m) {
               "Init gpu collective communication mode.");
   (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective,
               "Finalize gpu collective communication mode.");
 #endif
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  (void)m.def("get_ps_mode_rank", &mindspore::parallel::ps::Util::GetRankId, "Get Worker and PServer rank id.");
+#endif
   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
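Since `get_ps_mode_rank` is only bound when the build enables PS mode (`ENABLE_CPU` together with `ENABLE_D` or `ENABLE_GPU`), Python code that may run on other builds would need to guard the import. A hypothetical sketch:

```python
try:
    from mindspore._c_expression import get_ps_mode_rank
except ImportError:  # binding absent on builds without PS mode
    get_ps_mode_rank = None
```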

@@ -0,0 +1,23 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utils for parameter server training mode"""
+from mindspore._c_expression import get_ps_mode_rank
+
+
+def _get_ps_mode_rank():
+    ps_rank = get_ps_mode_rank()
+    if ps_rank == -1:
+        raise RuntimeError("The parameter server mode training is not launched yet.")
+    return ps_rank
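Because `Util::rank_id_` starts at -1 and is only set once the worker proxy or parameter server registers its rank (see the hunks above), `_get_ps_mode_rank()` doubles as a launch check. A hypothetical usage sketch (the log file name is illustrative):

```python
import os
from mindspore.parallel._ps_utils import _get_ps_mode_rank

# Illustrative only: tag per-role artifacts once PS mode is running.
if os.getenv("MS_ROLE") == "MS_PSERVER":
    log_name = "pserver_{}.log".format(_get_ps_mode_rank())
```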

@@ -280,6 +280,9 @@ class ModelCheckpoint(Callback):
             if save_ckpt:
                 cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                                    + str(step_num_in_epoch) + ".ckpt"
+                if os.getenv("MS_ROLE") == "MS_PSERVER":
+                    from mindspore.parallel._ps_utils import _get_ps_mode_rank
+                    cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
                 # update checkpoint file list.
                 self._manager.update_ckpoint_filelist(self._directory, self._prefix)
                 # keep checkpoint files number equal max number.
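For a parameter server with rank 2, prefix "resnet", epoch 3 and step 156 (all values hypothetical), the naming above would produce:

```python
prefix, cur_epoch_num, step_num_in_epoch, ps_rank = "resnet", 3, 156, 2
cur_ckpoint_file = prefix + "-" + str(cur_epoch_num) + "_" + str(step_num_in_epoch) + ".ckpt"
cur_ckpoint_file = "PServer_" + str(ps_rank) + "_" + cur_ckpoint_file
assert cur_ckpoint_file == "PServer_2_resnet-3_156.ckpt"
```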
