Merge pull request #1051 from jacquesqiao/add-pserver-util
Add ParameterServerController for parameter server python apiavx_docs
commit
f8a529cf1d
@ -0,0 +1,102 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "ParameterServerController.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
// Creates and initializes every ParameterServer2 instance described by
// `config`. A failed init() aborts through CHECK.
//
// The ports used are config.port() + i for i in
// [0, ports_num + ports_num_for_sparse).
//  - If config.nics() is empty, one server with an empty address is created
//    per port.
//  - Otherwise config.nics() is split on ',' and one server is created per
//    (port, device) pair, stored port-major:
//    index = portIndex * devices.size() + deviceIndex.
ParameterServerController::ParameterServerController(
    const ParameterServerConfig& config) {
  // round robin to load balance RDMA server ENGINE
  int rdmaCpu = 0;
  const int onlineCpus = rdma::numCpus();
  const int numPorts = config.ports_num() + config.ports_num_for_sparse();
  const bool useRdma = config.rdma_tcp() == "rdma";

  // Builds a single server for (addr, port). In RDMA mode the server is also
  // given the current CPU index, which then advances modulo onlineCpus.
  auto newServer = [&](const std::string& addr,
                       int port) -> ParameterServer2* {
    if (!useRdma) {
      return new ParameterServer2(addr, port);
    }
    ParameterServer2* server = new ParameterServer2(addr, port, rdmaCpu);
    rdmaCpu = (rdmaCpu + 1) % onlineCpus;
    return server;
  };

  if (config.nics().empty()) {
    // No network device configured: one server per port, empty address.
    parameterServers_.resize(numPorts);
    for (int i = 0; i < numPorts; ++i) {
      const int port = config.port() + i;
      parameterServers_[i].reset(newServer(std::string(), port));
      CHECK(parameterServers_[i]->init())
          << "Fail to initialize parameter server on port " << port;
    }
    return;
  }

  // One server per (port, device) pair.
  std::vector<std::string> devices;
  str::split(config.nics(), ',', &devices);
  parameterServers_.resize(devices.size() * numPorts);
  for (int i = 0; i < numPorts; ++i) {
    const int port = config.port() + i;
    for (size_t j = 0; j < devices.size(); ++j) {
      auto& server = parameterServers_[i * devices.size() + j];
      server.reset(newServer(getIpAddr(devices[j]), port));
      // Keep the device name and the port apart in the message; streaming
      // them back to back would print e.g. "xgbe020134".
      CHECK(server->init())
          << "Fail to initialize parameter server with device " << devices[j]
          << " on port " << port;
    }
  }
}
|
||||
|
||||
// Joins every owned parameter server (through wait()) before the members are
// destroyed.
ParameterServerController::~ParameterServerController() {
  wait();
}
|
||||
|
||||
// Copies the pserver-related gflags into a ParameterServerConfig and forwards
// it to create(). The controller is heap-allocated by create(); the caller
// owns the returned pointer.
ParameterServerController* ParameterServerController::createFromGflags() {
  ParameterServerConfig flagsConfig;

  // Port layout.
  flagsConfig.set_port(FLAGS_port);
  flagsConfig.set_ports_num(FLAGS_ports_num);
  flagsConfig.set_ports_num_for_sparse(FLAGS_ports_num_for_sparse);

  // Network devices and transport ("rdma" or anything else).
  flagsConfig.set_nics(FLAGS_nics);
  flagsConfig.set_rdma_tcp(FLAGS_rdma_tcp);

  return create(flagsConfig);
}
|
||||
|
||||
// Heap-allocates a controller built from `config`. The caller owns the
// returned pointer.
ParameterServerController* ParameterServerController::create(
    const ParameterServerConfig& config) {
  ParameterServerController* controller =
      new ParameterServerController(config);
  return controller;
}
|
||||
|
||||
// Logs the number of owned servers, then calls start() on each of them in
// storage order.
void ParameterServerController::start() {
  const size_t serverCount = parameterServers_.size();
  LOG(INFO) << "number of parameterServer instances: " << serverCount;

  for (size_t idx = 0; idx < serverCount; ++idx) {
    LOG(INFO) << "Starting parameterServer[" << idx << "]";
    parameterServers_[idx]->start();
  }
}
|
||||
|
||||
void ParameterServerController::wait() {
|
||||
int i = 0;
|
||||
for (const auto& parameterServer : parameterServers_) {
|
||||
LOG(INFO) << "Waiting parameterServer[" << i << "]";
|
||||
parameterServer->join();
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,74 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ParameterServer2.h"
|
||||
#include "ParameterServerConfig.pb.h"
|
||||
#include "RDMANetwork.h"
|
||||
#include "paddle/utils/StringUtil.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
/**
|
||||
* @brief ParameterServerController is used for create, init and manage multi
|
||||
* parameter server instances. The num of the instances is decided by port
|
||||
* num(the ports number for parameter send) and network devices configured
|
||||
* by gflags or proto.
|
||||
*/
|
||||
class ParameterServerController final {
|
||||
public:
|
||||
DISABLE_COPY(ParameterServerController);
|
||||
|
||||
/**
|
||||
* @brief Ctor, Create a ParameterServerController from ParameterServerConfig.
|
||||
*/
|
||||
explicit ParameterServerController(const ParameterServerConfig& config);
|
||||
|
||||
/**
|
||||
* @brief Dtor.
|
||||
*/
|
||||
~ParameterServerController();
|
||||
|
||||
/**
|
||||
* @brief create ParameterServerController from gflags, this is used for
|
||||
* compatibility with the old usage of configuration by gflags.
|
||||
*/
|
||||
static ParameterServerController* createFromGflags();
|
||||
|
||||
/**
|
||||
* @brief create ParameterServerController with ParameterServerConfig, remove
|
||||
* gflags from ParameterServer. Init all ParameterServer2 instances according
|
||||
* to
|
||||
* the config.
|
||||
*/
|
||||
static ParameterServerController* create(const ParameterServerConfig& config);
|
||||
|
||||
/**
|
||||
* @brief start all ParameterServer2 instances in this
|
||||
* ParameterServerController.
|
||||
*/
|
||||
void start();
|
||||
|
||||
/**
|
||||
* @brief join and wait for all ParameterServer2 instances thread in this
|
||||
* ParameterServerController.
|
||||
*/
|
||||
void wait();
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<ParameterServer2>> parameterServers_;
|
||||
};
|
||||
|
||||
} // namespace paddle
|
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
syntax = "proto2";
|
||||
|
||||
package paddle;
|
||||
|
||||
|
||||
/**
 * Configuration for ParameterClient2.
 */
message ParameterClientConfig {
  // NOTE(review): presumably the id of the trainer this client belongs to;
  // confirm against ParameterClient2.
  required int32 trainer_id = 1;
}
|
||||
|
||||
/**
 * Configuration for ParameterServer2.
 */
message ParameterServerConfig {
  // Number of ports used for sending parameters. The ports are consecutive,
  // counted upwards from `port`.
  required int32 ports_num = 1 [default = 1];

  // Number of ports used for sending sparse parameters. These ports are
  // consecutive as well, counted upwards from (port + ports_num).
  required int32 ports_num_for_sparse = 2 [default = 0];

  // Comma-separated names of the network devices used by the pservers.
  required string nics = 3 [default = "xgbe0,xgbe1"];

  // Transport selection: "rdma" enables the RDMA code path, any other value
  // (such as the default "tcp") does not.
  required string rdma_tcp = 4 [default = "tcp"];

  // Base listening port of the pserver.
  required int32 port = 5 [default = 20134];

  // Number of gradient servers.
  required int32 num_gradient_servers = 6 [default = 1];

  // Number of threads used to execute sync ops.
  required int32 pserver_num_threads = 7 [default = 1];

  // Minimum value allowed for config_.async_lagged_grad_discard_ratio().
  required double async_lagged_ratio_min = 8 [default = 1.0];

  // Default value used when async_lagged_grad_discard_ratio is not set in
  // trainer_config.conf.
  required double async_lagged_ratio_default = 9 [default = 1.5];
}
|
Loading…
Reference in new issue