You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/distributed/service/env.h

284 lines
8.3 KiB

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <arpa/inet.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
namespace paddle {
namespace distributed {
// Identity of one parameter-server participant: IPv4 address text, port and
// rank.
//
// Two encodings are provided:
//   * a packed 64-bit label (for pslib):
//       |---ip---|---port---|--rank--|
//       |-32bit--|--20bit---|--12bit-|
//   * a textual "ip:port:rank" endpoint (for the open source parameter
//     server).
struct PSHost {
  std::string ip;
  // Zero-initialised so a default-constructed host never exposes
  // indeterminate values through to_string()/serialize_*().
  uint32_t port = 0;
  uint32_t rank = 0;

  PSHost() = default;
  PSHost(const std::string &ip, uint32_t port, uint32_t rank)
      : ip(ip), port(port), rank(rank) {}

  // Packs ip/port/rank into the 64-bit label described above.
  //
  // The ip field is whatever inet_addr() returns for `ip` (including its
  // failure value when `ip` is not a dotted-quad address). `port` and `rank`
  // are truncated to 20 and 12 bits respectively so that an out-of-range
  // value cannot spill into the neighbouring field.
  uint64_t serialize_to_uint64() const {
    constexpr uint64_t kRankMask = (uint64_t{1} << 12) - 1;
    constexpr uint64_t kPortMask = (uint64_t{1} << 20) - 1;
    uint64_t host_label = inet_addr(ip.c_str());
    host_label <<= 32;
    host_label |= (static_cast<uint64_t>(port) & kPortMask) << 12;
    host_label |= static_cast<uint64_t>(rank) & kRankMask;
    return host_label;
  }

  // Inverse of serialize_to_uint64(): overwrites ip, port and rank from the
  // packed label.
  void parse_from_uint64(uint64_t host_label) {
    constexpr uint64_t kRankMask = (uint64_t{1} << 12) - 1;
    constexpr uint64_t kPortMask = (uint64_t{1} << 20) - 1;
    rank = static_cast<uint32_t>(host_label & kRankMask);
    port = static_cast<uint32_t>((host_label >> 12) & kPortMask);

    // inet_ntop writes into a caller-owned buffer, unlike inet_ntoa which
    // hands back a shared static buffer; it also avoids reinterpreting a
    // uint32_t as an in_addr through a pointer cast.
    in_addr addr{};
    addr.s_addr = static_cast<in_addr_t>(host_label >> 32);
    char text[INET_ADDRSTRLEN] = {0};
    const char *formatted = inet_ntop(AF_INET, &addr, text, sizeof(text));
    ip = (formatted != nullptr) ? formatted : "";
  }

  // Human-readable dump of every field plus the packed label.
  std::string to_string() const {
    std::stringstream s;
    s << "host: " << ip;
    s << " port: " << port;
    s << " rank: " << rank;
    s << " uint: " << serialize_to_uint64();
    return s.str();
  }

  // for open source parameter server
  // Returns "ip:port:rank".
  std::string serialize_to_string() const {
    std::stringstream s;
    s << ip << ":";
    s << port << ":";
    s << rank;
    return s.str();
  }

  // Parses an "ip:port:rank" endpoint and overwrites ip, port and rank.
  //
  // Throws std::invalid_argument when fewer than three ':'-separated fields
  // are present (previously this indexed past the end of the vector), and
  // propagates std::invalid_argument / std::out_of_range from std::stoi when
  // port or rank is not a valid integer. The object is left unmodified if an
  // exception is thrown.
  void parse_from_string(const std::string &endpoint) {
    std::vector<std::string> endpoint_info;
    string_split(endpoint, ':', &endpoint_info);
    if (endpoint_info.size() < 3) {
      throw std::invalid_argument(
          "PSHost endpoint must look like \"ip:port:rank\", got: " + endpoint);
    }
    // Convert everything first so a conversion failure cannot leave the host
    // half-updated.
    const uint32_t parsed_port =
        static_cast<uint32_t>(std::stoi(endpoint_info[1]));
    const uint32_t parsed_rank =
        static_cast<uint32_t>(std::stoi(endpoint_info[2]));
    ip = endpoint_info[0];
    port = parsed_port;
    rank = parsed_rank;
  }

  // Splits `str` on `sep` into `pieces` (which is cleared first).
  //
  // Empty fields between two separators are kept; an empty field after the
  // final separator is dropped. `ignore_null` only affects an entirely empty
  // `str`: when false, a single empty piece is emitted instead of none.
  static void string_split(const std::string &str, char sep,
                           std::vector<std::string> *pieces,
                           bool ignore_null = true) {
    pieces->clear();
    if (str.empty()) {
      if (!ignore_null) {
        pieces->push_back(str);
      }
      return;
    }
    size_t pos = 0;
    for (size_t next = str.find(sep, pos); next != std::string::npos;
         next = str.find(sep, pos)) {
      pieces->push_back(str.substr(pos, next - pos));
      pos = next + 1;
    }
    // `pos` never exceeds str.size() here, so this is the "non-empty tail"
    // check without building a throwaway substring.
    if (pos < str.size()) {
      pieces->push_back(str.substr(pos));
    }
  }
};
class PSEnvironment {
public:
explicit PSEnvironment() {} // NOLINT
virtual ~PSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_servers(
const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(std::string *host_endpoint_list,
int node_num) {
return 0;
}
virtual uint64_t get_local_host_sign() { return 0; }
virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; }
virtual int32_t registe_ps_server(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_server_list,
_ps_server_sign_set);
}
virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; }
virtual int32_t registe_ps_client(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_client_list,
_ps_client_sign_set);
}
virtual std::vector<uint64_t> get_client_info() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_sign_set) {
client_info.push_back(i);
}
return client_info;
}
virtual std::vector<std::string> get_client_info(bool use_string_endpoint) {
if (use_string_endpoint) {
std::vector<std::string> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_string());
}
return client_info;
}
return {};
}
virtual void set_trainers(int trainers) { trainers_ = trainers; }
virtual int get_trainers() { return trainers_; }
protected:
//注册一个host // NOLINT
virtual int32_t registe_ps_host(
const std::string &ip, uint32_t port, int32_t rank,
std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT
PSHost host;
host.ip = ip;
host.port = port;
host.rank = rank;
if (sign_set.count(rank) > 0) {
LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
<< ", rank:" << host.rank
<< " already register, ignore register";
} else {
host_list.push_back(host);
sign_set.insert(rank);
}
return 0;
}
int trainers_ = 0;
std::vector<PSHost> _ps_client_list;
std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter
std::vector<PSHost> _ps_server_list;
std::unordered_set<uint64_t> _ps_server_sign_set; // for unique filter
};
class PaddlePSEnvironment : public PSEnvironment {
public:
explicit PaddlePSEnvironment() {} // NOLINT
virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.serialize_to_uint64());
}
}
std::sort(
_ps_server_list.begin(), _ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank);
}
}
std::sort(
_ps_server_list.begin(), _ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.serialize_to_uint64());
}
}
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t set_ps_clients(std::vector<std::string> *host_sign_list,
int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank);
}
}
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual uint64_t get_local_host_sign() {
if (_ps_client_list.size() > 0) {
return _ps_client_list[0].serialize_to_uint64();
} else {
return 0;
}
}
};
} // namespace distributed
} // namespace paddle