|
|
|
@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) {
|
|
|
|
|
bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
|
|
|
|
|
|
|
|
|
|
int Util::LocalShard(int first_dim, int rank_id, int server_num) {
|
|
|
|
|
int shard_size = std::round((static_cast<float>(first_dim)) / server_num);
|
|
|
|
|
int remain_size = first_dim % server_num;
|
|
|
|
|
if (remain_size == 0 || rank_id < server_num - 1) {
|
|
|
|
|
return shard_size;
|
|
|
|
|
} else {
|
|
|
|
|
return first_dim - (shard_size * (server_num - 1));
|
|
|
|
|
std::map<int, int> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
|
|
|
|
|
if (shard_dims.count(rank_id) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
|
|
|
|
|
}
|
|
|
|
|
return shard_dims[rank_id];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int server_num) {
|
|
|
|
|
if (rank_id >= server_num) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
|
|
|
|
|
}
|
|
|
|
|
std::map<int, int> shard_dims;
|
|
|
|
|
for (int i = 0; i < server_num; i++) {
|
|
|
|
|
shard_dims[i] = 0;
|
|
|
|
|
}
|
|
|
|
|
if (server_num != static_cast<int>(shard_dims.size())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
|
|
|
|
|
}
|
|
|
|
|
int server_index = -1;
|
|
|
|
|
for (int i = 0; i < first_dim; i++) {
|
|
|
|
|
server_index = (server_index + 1) % server_num;
|
|
|
|
|
shard_dims[server_index] = shard_dims[server_index] + 1;
|
|
|
|
|
}
|
|
|
|
|
if (shard_dims.count(rank_id) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
|
|
|
|
|
}
|
|
|
|
|
return shard_dims;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
|
|
|
|
|