Paddle/paddle/trainer/RemoteParameterUpdater.cpp

/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "RemoteParameterUpdater.h"
#include "Trainer.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/GlobalConstants.h"
P_DECLARE_int32(trainer_id);
P_DECLARE_string(save_dir);
namespace paddle {
static const hl_stream_t kDeviceToHostStream = HPPL_STREAM_1;
static const hl_stream_t kHostToDeviceStream = HPPL_STREAM_2;
static const int kFinishBatchPid = -1;
const std::string RemoteParameterUpdater::kAverage = "average";
const std::string RemoteParameterUpdater::kElasticAverage = "elastic_average";
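// RemoteParameterUpdater sends gradients (or locally accumulated deltas) to
// remote parameter servers and pulls the updated values back. An optional
// local updater enables the "center parameter" modes (average /
// elastic_average) checked in init().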
RemoteParameterUpdater::RemoteParameterUpdater(
const OptimizationConfig& config, int expectedPassCount,
std::unique_ptr<ParameterUpdater>&& localUpdater)
: config_(config),
localUpdater_(std::move(localUpdater)),
numBatches_(0),
passCount_(0),
expectedPassCount_(expectedPassCount),
separateSendAndRecv_(false),
isFirstPass_(true),
useApplyInPserver_(false) {
addParameterType(PARAMETER_MOMENTUM);
}
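// init() wires the updater to the parameter servers:
// - initialize the optional local updater and enable PARAMETER_DELTA;
// - keep CPU-side shadow copies of the parameters when training on GPU;
// - create the ParameterClient2; trainer 0 pushes the initial parameter
//   values and marks PSERVER_STATUS_PARAMETER_READY, other trainers wait
//   for that status and pull the values;
// - for synchronous SGD, trainer 0 also starts the controller thread.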
void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
if (localUpdater_) {
localUpdater_->init(parameters);
for (auto& parameter : parameters) {
parameter->enableType(PARAMETER_DELTA);
}
CHECK(config_.center_parameter_update_method() == kAverage ||
config_.center_parameter_update_method() == kElasticAverage)
<< "unknown center_parameter_update_method";
// modify delta_add_rate
CHECK_GT(FLAGS_num_gradient_servers, 1)
<< "FLAGS_num_gradient_servers should be set in trainer args.";
real delta_add_rate = config_.delta_add_rate() / FLAGS_num_gradient_servers;
config_.set_delta_add_rate(delta_add_rate);
LOG(INFO) << "center parameter in pserver,"
<< " modify delta_add_rate=" << delta_add_rate;
}
if (!FLAGS_use_gpu) {
cpuParameters_ = parameters;
} else {
for (auto& parameter : parameters) {
cpuParameters_.emplace_back(new Parameter(parameter->getConfig(),
/* useGpu= */ false));
cpuParameters_.back()->setID(parameter->getID());
if (localUpdater_) {
cpuParameters_.back()->enableType(PARAMETER_DELTA);
}
}
}
parameterClient_.reset(new ParameterClient2(separateSendAndRecv_));
parameterClient_->init(cpuParameters_);
parameterClient_->setTrainerId(FLAGS_trainer_id);
if (FLAGS_trainer_id == 0) {
parameterClient_->setConfig(config_);
copyParametersFromDevice(PARAMETER_VALUE);
parameterClient_->setParameter();
parameterClient_->setStatus(PSERVER_STATUS_PARAMETER_READY);
} else {
parameterClient_->waitForStatus(PSERVER_STATUS_PARAMETER_READY);
parameterClient_->getParameter();
copyParametersToDevice(PARAMETER_VALUE);
}
if (FLAGS_trainer_id == 0 && (config_.algorithm()
!= TrainAlgorithm::AsyncSGD)) {
startController();
useApplyInPserver_ = useApplyInPserver(config_);
}
}
void RemoteParameterUpdater::startController() {
controllerThread_.reset(new std::thread([this]() { this->controller(); }));
}
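// The controller thread drives the pserver-side optimization: it waits for a
// pass to start, issues PSERVER_OP_SGD once per mini-batch until the pass is
// finished, closes the pass, and repeats until expectedPassCount_ passes are
// done.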
void RemoteParameterUpdater::controller() {
ParameterClient2 client(false);
client.init(cpuParameters_);
while (true) {
/*start pass*/ {
client.waitPassStart();
PreparedOperations ops;
ops.addOperation(PSERVER_OP_START_PASS);
client.doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false,
/* releasePass= */ false);
}
while (true) {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_SGD);
client.doOperation(ops,
/* waitForGradient= */ true,
/* sendBackParameter= */ true,
/* releasePass= */ false);
if (client.isPassFinish()) {
break;
}
}
/*finish pass*/ {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_FINISH_PASS);
client.doOperation(ops,
/* waitForGradient= */ true,
/* sendBackParameter= */ true,
/* releasePass= */ true);
}
passCount_++;
if (passCount_ == expectedPassCount_) {
break;
}
}
}
void RemoteParameterUpdater::copyParametersToDevice(
ParameterType parameterType) {
if (!FLAGS_use_gpu) {
return;
}
int numParameters = cpuParameters_.size();
for (int i = 0; i < numParameters; ++i) {
parameters_[i]
->getBuf(parameterType)
->copyFrom(*cpuParameters_[i]->getBuf(parameterType));
if (parameterType == PARAMETER_VALUE) {
parameters_[i]->setValueUpdated();
}
}
}
void RemoteParameterUpdater::copyParametersFromDevice(
ParameterType parameterType) {
if (!FLAGS_use_gpu) {
return;
}
int numParameters = cpuParameters_.size();
for (int i = 0; i < numParameters; ++i) {
cpuParameters_[i]
->getBuf(parameterType)
->copyFrom(*parameters_[i]->getBuf(parameterType));
}
}
void RemoteParameterUpdater::updateImpl(Parameter* para) {
REGISTER_TIMER("update");
if (localUpdater_) {
localUpdater_->update(para);
}
}
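// finishBatch() exchanges data with the pservers at the end of a mini-batch:
// with a local updater the accumulated delta is sent (only every
// num_batches_per_send_parameter batches), otherwise the raw gradient is
// sent; the returned values, when requested, are copied back to the device.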
void RemoteParameterUpdater::finishBatch(real cost) {
if (localUpdater_) {
localUpdater_->finishBatch(cost);
}
const std::string& algorithm = config_.algorithm();
ParameterUpdateMode mode;
if (algorithm == TrainAlgorithm::AsyncSGD) {
mode = PSERVER_UPDATE_MODE_ASYNC_SGD;
} else if (algorithm == TrainAlgorithm::SGD) {
mode = PSERVER_UPDATE_MODE_ADD_GRADIENT;
} else {
LOG(FATAL) << "Unknown algorithm: " << algorithm;
}
ParameterType sendType;
bool sendBackParameter = true;
if (localUpdater_) {
++numBatches_;
if (numBatches_ % config_.num_batches_per_send_parameter() != 0) {
return;
}
if (config_.center_parameter_update_method() == kElasticAverage) {
parameterClient_->getParameter(PARAMETER_DELTA);
copyParametersToDevice(PARAMETER_DELTA);
sendBackParameter = false; // no need to send back after sending
// calc delta
for (auto& para : parameters_) {
// DELTA = LOCAL_VALUE - CENTER_VALUE/*store in DELTA*/
para->getBuf(PARAMETER_DELTA)
->add(*para->getBuf(PARAMETER_VALUE), -1.0f, 1.0f);
// when the delta is sent to the pserver, the pserver will do:
// CENTER_VALUE += alpha * (LOCAL_VALUE - CENTER_VALUE)
}
} else {
// calc delta
for (auto& para : parameters_) {
// DELTA = NEW_VALUE - OLD_VALUE/*store in DELTA*/
para->getBuf(PARAMETER_DELTA)
->add(*para->getBuf(PARAMETER_VALUE), -1.0f, 1.0f);
}
}
sendType = PARAMETER_DELTA;
} else {
// In this case, we perform SGD on pserver.
sendType = PARAMETER_GRADIENT;
}
copyParametersFromDevice(sendType);
{
REGISTER_TIMER("sendAndRecv_dense");
parameterClient_->sendAndReceiveParameter(mode, sendType, batchSize_,
0, // cost = 0
sendBackParameter);
}
if (sendBackParameter) {
copyParametersToDevice(PARAMETER_VALUE);
}
if (localUpdater_) {
if (config_.center_parameter_update_method() == kElasticAverage) {
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
// LOCAL_VALUE += -alpha * (LOCAL_VALUE - CENTER_VALUE)
para->getBuf(PARAMETER_VALUE)
->add(*para->getBuf(PARAMETER_DELTA), -config_.delta_add_rate());
}
} else { // average
// copy value to delta
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
para->getBuf(PARAMETER_DELTA)->copyFrom(*para->getBuf(PARAMETER_VALUE));
}
}
} else {
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
para->getBuf(sendType)->zeroMem();
}
}
}
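// startPass() synchronizes the pass boundary with the pservers and, when a
// local updater is present, prepares PARAMETER_DELTA for the configured
// center-parameter mode.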
void RemoteParameterUpdater::startPass() {
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassStart();
} else {
// Syncing here helps reduce trainer lag for async-SGD, even though it
// cannot eliminate all lag caused by file loading, buffering, etc.
parameterClient_->asyncStartPass();
}
if (localUpdater_) {
localUpdater_->startPass();
numBatches_ = 0;
if (config_.center_parameter_update_method() == kElasticAverage) {
if (!isFirstPass_) {
// restore local value from delta
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
para->getBuf(PARAMETER_VALUE)
->copyFrom(*para->getBuf(PARAMETER_DELTA));
}
}
} else { // average
// copy value to delta
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
para->getBuf(PARAMETER_DELTA)->copyFrom(*para->getBuf(PARAMETER_VALUE));
}
}
}
}
bool RemoteParameterUpdater::finishPass(real cost) {
if (localUpdater_) {
localUpdater_->finishPass();
}
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassFinish();
} else {
parameterClient_->asyncFinishPass();
}
if (localUpdater_) {
if (config_.center_parameter_update_method() == kElasticAverage) {
// back up the local value into delta, since we will fetch
// the remote parameter for saving/testing
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
para->getBuf(PARAMETER_DELTA)->copyFrom(*para->getBuf(PARAMETER_VALUE));
}
}
}
parameterClient_->getParameter();
copyParametersToDevice(PARAMETER_VALUE);
isFirstPass_ = false;
return true;
}
void RemoteParameterUpdater::apply() {
if (useApplyInPserver_) {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_APPLY);
parameterClient_->doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false);
parameterClient_->getParameter(
/* recvParameterType= */ PARAMETER_VALUE,
/* sendBackParameterType= */ PARAMETER_APPLY);
copyParametersToDevice(PARAMETER_VALUE);
}
}
void RemoteParameterUpdater::restore() {
if (useApplyInPserver_) {
parameterClient_->getParameter();
copyParametersToDevice(PARAMETER_VALUE);
}
}
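// ConcurrentRemoteParameterUpdater overlaps communication with computation:
// a dedicated send thread and a dedicated recv thread consume parameter ids
// from sendQueue_/recvQueue_, so parameter transfers proceed while other
// parameters are still being computed.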
ConcurrentRemoteParameterUpdater::ConcurrentRemoteParameterUpdater(
OptimizationConfig config, int passCount,
std::unique_ptr<ParameterUpdater>&& localUpdater)
: RemoteParameterUpdater(config, passCount, std::move(localUpdater)) {
sendThread_.reset(new std::thread([this]() { this->send(); }));
recvThread_.reset(new std::thread([this]() { this->recv(); }));
stopping_ = false;
oneBatchFinished_ = false;
separateSendAndRecv_ = true;
}
ConcurrentRemoteParameterUpdater::~ConcurrentRemoteParameterUpdater() {
stopping_ = true;
sendQueue_.enqueue(0);
sendThread_->join();
recvQueue_.enqueue(0);
recvThread_->join();
}
void ConcurrentRemoteParameterUpdater::finishBatch(real cost) {
if (localUpdater_) {
localUpdater_->finishBatch(cost);
if (!needToUpdateRemotely()) {
++numBatches_;
return;
}
}
sendQueue_.enqueue(kFinishBatchPid);
finishBatchCond_.wait([this]() { return oneBatchFinished_; });
oneBatchFinished_ = false;
{
REGISTER_TIMER("sync_hostToDeviceStream");
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
hl_stream_synchronize(kHostToDeviceStream);
}
}
if (localUpdater_) {
++numBatches_;
}
}
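// send(para) transfers a single parameter segment to the pservers; for GPU
// training the data is first staged into the CPU shadow copy through
// kDeviceToHostStream before sendParameter() is called.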
// Use para=NULL to signal the end of one batch
void ConcurrentRemoteParameterUpdater::send(Parameter* para) {
const std::string& algorithm = config_.algorithm();
ParameterUpdateMode mode;
if (algorithm == TrainAlgorithm::AsyncSGD) {
mode = PSERVER_UPDATE_MODE_ASYNC_SGD;
} else if (algorithm == TrainAlgorithm::SGD) {
mode = PSERVER_UPDATE_MODE_ADD_GRADIENT;
} else {
LOG(FATAL) << "Unknown algorithm: " << algorithm;
}
ParameterType sendType;
if (localUpdater_) {
sendType = PARAMETER_DELTA;
} else {
// In this case, we perform SGD on pserver.
sendType = PARAMETER_GRADIENT;
}
std::vector<ParameterSegments> paraSegment;
if (para == NULL) {
parameterClient_->sendParameter(
mode, sendType, paraSegment, batchSize_,
0, // cost=0
true, // sendBackParameter = true
batchStatus_); // batchStatus_ = BATCH_FINISH
} else {
ParameterSegments paraSegTemp;
paraSegment.reserve(1);
paraSegTemp.name = para->getName();
paraSegTemp.id = para->getID();
paraSegment.push_back(paraSegTemp);
{
SetDevice device(para->getDeviceId());
REGISTER_TIMER("copySingleParaFromDevice");
copySingleParaFromDevice(para, sendType);
hl_stream_synchronize(kDeviceToHostStream);
}
parameterClient_->sendParameter(mode, sendType, paraSegment, batchSize_,
0, // cost=0
true, // sendBackParameter = true
batchStatus_);
if (batchStatus_ == BATCH_START) batchStatus_ = BATCH_ON;
}
}
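// recv(para) receives data from the pservers; with a non-NULL para the value
// is copied back to the device and mirrored into PARAMETER_DELTA when a
// local updater is used (otherwise the CPU-side gradient is zeroed here).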
void ConcurrentRemoteParameterUpdater::recv(Parameter* para) {
parameterClient_->recvParameter();
if (para != NULL) {
REGISTER_TIMER("copySingleParaToDevice");
SetDevice device(para->getDeviceId());
copySingleParaToDevice(para, PARAMETER_VALUE);
if (localUpdater_) {
para->getBuf(PARAMETER_DELTA)->copyFrom(*para->getBuf(PARAMETER_VALUE));
} else {
// On CPU, the gradient must not change until recvParameter() returns.
// On GPU, the gradient memory is zeroed when the send finishes.
if (!FLAGS_use_gpu) {
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
}
}
}
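// recv thread loop: dequeues parameter ids, receives the corresponding
// parameters, and signals finishBatch() once the batch-finish marker
// (kFinishBatchPid) has been processed.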
void ConcurrentRemoteParameterUpdater::recv() {
hl_set_device(FLAGS_gpu_id);
StatPtr stat = getStat("recv");
FOR_TIMING(Timer timer);
while (true) {
int pid;
{
REGISTER_TIMER("recv_dequeue");
pid = recvQueue_.dequeue();
}
if (pid == kFinishBatchPid) {
Parameter* para = NULL;
FOR_TIMING(timer.start());
recv(para);
FOR_TIMING(timer.stop());
FOR_TIMING(stat->addSample(timer.get()));
FOR_TIMING(timer.reset());
finishBatchCond_.notify_all([this] { oneBatchFinished_ = true; });
} else {
if (stopping_) break;
Parameter* para = parameters_[pid].get();
FOR_TIMING(timer.start());
recv(para);
FOR_TIMING(timer.stop());
oneBatchFinished_ = false;
}
}
}
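// send thread loop: dequeues parameter ids, computes the delta when a local
// updater is present, sends the data, and forwards the id to the recv queue
// so the updated value can be fetched back. The batch-finish marker also
// zeroes the GPU gradients before it is sent.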
void ConcurrentRemoteParameterUpdater::send() {
hl_set_device(FLAGS_gpu_id);
StatPtr stat = getStat("send");
FOR_TIMING(Timer timer);
while (true) {
int pid;
{
REGISTER_TIMER("send_dequeue");
pid = sendQueue_.dequeue();
}
if (pid == kFinishBatchPid) {
batchStatus_ = BATCH_FINISH;
if (!localUpdater_) {
// On CPU, the gradient must not change until recvParameter() returns.
// On GPU, zeroMem() is called at the end of the batch so that it does not
// interfere with computation.
if (FLAGS_use_gpu) {
REGISTER_TIMER("para_zeroMem");
for (auto& para : parameters_) {
SetDevice device(para->getDeviceId());
para->getBuf(PARAMETER_GRADIENT)->zeroMem();
}
}
}
Parameter* para = NULL;
FOR_TIMING(timer.start());
send(para);
FOR_TIMING(timer.stop());
FOR_TIMING(stat->addSample(timer.get()));
FOR_TIMING(timer.reset());
recvQueue_.enqueue(pid);
} else {
if (stopping_) break;
Parameter* para = parameters_[pid].get();
if (localUpdater_) {
// DELTA = NEW_VALUE - OLD_VALUE/*store in DELTA*/
para->getBuf(PARAMETER_DELTA)
->add(*para->getBuf(PARAMETER_VALUE), -1.0f, 1.0f);
}
FOR_TIMING(timer.start());
send(para);
FOR_TIMING(timer.stop());
recvQueue_.enqueue(nonStaticParaIDMap_[para->getID()]);
}
}
}
void ConcurrentRemoteParameterUpdater::updateImpl(Parameter* para) {
REGISTER_TIMER("update");
if (localUpdater_) {
localUpdater_->update(para);
if (!needToUpdateRemotely()) {
return;
}
}
sendQueue_.enqueue(nonStaticParaIDMap_[para->getID()]);
}
void ConcurrentRemoteParameterUpdater::copySingleParaToDevice(
Parameter* para, ParameterType parameterType) {
if (!FLAGS_use_gpu) {
return;
}
int i = nonStaticParaIDMap_[para->getID()];
para->getBuf(parameterType)
->copyFrom(*cpuParameters_[i]->getBuf(parameterType),
kHostToDeviceStream);
if (parameterType == PARAMETER_VALUE) {
para->setValueUpdated();
}
}
void ConcurrentRemoteParameterUpdater::copySingleParaFromDevice(
Parameter* para, ParameterType parameterType) {
if (!FLAGS_use_gpu) {
return;
}
int i = nonStaticParaIDMap_[para->getID()];
cpuParameters_[i]
->getBuf(parameterType)
->copyFrom(*para->getBuf(parameterType), kDeviceToHostStream);
}
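// SparseRemoteParameterUpdater talks to a separate group of pserver ports
// (FLAGS_port + FLAGS_ports_num, with FLAGS_ports_num_for_sparse ports)
// reserved for sparsely updated parameters.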
SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
const OptimizationConfig& config, int expectedPassCount, bool testing)
: config_(config),
passCount_(0),
expectedPassCount_(expectedPassCount),
testing_(testing),
useApplyInPserver_(false) {}
void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
parameterClient_.reset(new ParameterClient2(false,
FLAGS_port + FLAGS_ports_num, FLAGS_ports_num_for_sparse));
parameterClient_->init(parameters_);
parameterClient_->setTrainerId(FLAGS_trainer_id);
if (FLAGS_trainer_id == 0) {
parameterClient_->setConfig(config_, FLAGS_save_dir,
true /*is_sparse_server*/);
if (parameters[0]->isFullSize()) {
parameterClient_->setParameter();
} else { // init in pserver
parameterClient_->setParameterZero();
}
}
if (FLAGS_trainer_id == 0 && !testing_ &&
config_.algorithm() == TrainAlgorithm::SGD) {
startController();
useApplyInPserver_ = useApplyInPserver(config_);
}
}
void SparseRemoteParameterUpdater::startController() {
controllerThread_.reset(new std::thread([this]() { this->controller(); }));
}
void SparseRemoteParameterUpdater::controller() {
ParameterClient2 client(false,
FLAGS_port + FLAGS_ports_num,
FLAGS_ports_num_for_sparse);
client.init(parameters_);
while (true) {
/*start pass*/ {
client.waitPassStart();
PreparedOperations ops;
ops.addOperation(PSERVER_OP_START_PASS);
client.doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false,
/* releasePass= */ false);
}
while (true) {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_SGD);
client.doOperation(ops,
/* waitForGradient= */ true,
/* sendBackParameter= */ true,
/* releasePass= */ false);
if (client.isPassFinish()) {
break;
}
}
/*finish pass*/ {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_FINISH_PASS);
client.doOperation(ops,
/* waitForGradient= */ true,
/* sendBackParameter= */ true,
/* releasePass= */ true);
}
passCount_++;
if (passCount_ == expectedPassCount_) {
break;
}
}
}
PassType SparseRemoteParameterUpdater::startBatch(int64_t batchSize) {
batchSize_ = batchSize;
return PASS_TRAIN;
}
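// finishBatch() sends the sparse gradients to the pservers without asking for
// the values back; the trainer fetches values explicitly through
// getParametersRemote() when it needs them.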
void SparseRemoteParameterUpdater::finishBatch(real cost) {
const std::string& algorithm = config_.algorithm();
ParameterUpdateMode mode;
if (algorithm == TrainAlgorithm::AsyncSGD) {
mode = PSERVER_UPDATE_MODE_ASYNC_SGD;
} else if (algorithm == TrainAlgorithm::SGD) {
mode = PSERVER_UPDATE_MODE_ADD_GRADIENT;
} else {
LOG(FATAL) << "Unknown algorithm: " << algorithm;
}
ParameterType sendType = PARAMETER_GRADIENT;
REGISTER_TIMER("sendSparseParam");
parameterClient_->sendAndReceiveParameter(mode, sendType, batchSize_,
0, // cost = 0
false); // sendBackParameter
// zeroing the gradient has been moved to the SGD gradient machine, done
// before the sparse remote gradients are merged
}
void SparseRemoteParameterUpdater::startPass() {
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassStart();
} else {
if (FLAGS_trainer_id == 0) {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_START_PASS);
parameterClient_->doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false);
}
parameterClient_->asyncStartPass();
}
}
bool SparseRemoteParameterUpdater::finishPass(real cost) {
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassFinish();
} else {
if (FLAGS_trainer_id == 0) {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_FINISH_PASS);
parameterClient_->doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false);
}
parameterClient_->asyncFinishPass();
}
return true;
}
// The trainer calls getParametersRemote() at batch start or before saving,
// so apply() and restore() do not fetch parameter values themselves.
void SparseRemoteParameterUpdater::apply() {
if (useApplyInPserver_) {
PreparedOperations ops;
ops.addOperation(PSERVER_OP_APPLY);
parameterClient_->doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false);
}
}
void SparseRemoteParameterUpdater::restore() {}
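// getParametersRemote() fetches values from the pservers, either full-size or
// sparse, and optionally applies L1 shrinking locally when
// shrink_parameter_value is configured.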
void SparseRemoteParameterUpdater::getParametersRemote(bool fullSize,
bool apply) {
ParameterType sendBackParameterType =
(useApplyInPserver_ && apply) ? PARAMETER_APPLY : PARAMETER_VALUE;
if (fullSize) {
parameterClient_->getParameter(
/* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType);
if (config_.shrink_parameter_value() > 0) {
for (auto& para : parameters_) {
if (para->getConfig().decay_rate_l1() > 0) {
para->getBuf(PARAMETER_VALUE)
->applyL1(1.0f, // learningRate
config_.shrink_parameter_value()); // decayRate
}
}
}
} else {
REGISTER_TIMER("getParamSparse");
parameterClient_->getParameterSparse(
/* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType);
if (config_.shrink_parameter_value() > 0) {
for (auto& para : parameters_) {
if (para->getConfig().decay_rate_l1() > 0) {
para->getPrefetchMatrix()->applyL1Decay(
1.0f, // learningRate
config_.shrink_parameter_value()); // decayRate
}
}
}
}
}
void SparseRemoteParameterUpdater::randParametersRemote() {
CHECK_EQ(FLAGS_trainer_id, 0);
PreparedOperations ops;
ops.addOperation(PSERVER_OP_RANDOMIZE);
parameterClient_->doOperation(ops,
/* waitForGradient= */ false,
/* sendBackParameter= */ false);
}
void SparseRemoteParameterUpdater::loadParametersRemote(
const std::string& dirName) {
if (FLAGS_trainer_id == 0) {
parameterClient_->loadValueVector(dirName);
}
if (testing_) {
// we do not use synchronize() here,
// because test mode may run only one tester
if (FLAGS_trainer_id == 0) {
parameterClient_->setStatus(PSERVER_STATUS_PARAMETER_READY);
} else {
parameterClient_->waitForStatus(PSERVER_STATUS_PARAMETER_READY);
}
}
}
void SparseRemoteParameterUpdater::saveParametersRemote(
const std::string& dirName) {
if (FLAGS_trainer_id == 0) {
parameterClient_->saveValueVector(dirName);
}
}
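// The composite updater splits the parameters into a sparse-remote group and
// a normal group, then initializes the matching sub-updater for each group on
// its own thread from syncThreadPool_.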
void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;
std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];
for (auto& para : parameters_) {
if (para->isSparseRemoteUpdate()) {
parametersArray[UPDATER_SPARSE_REMOTE].push_back(para);
} else {
parametersArray[UPDATER_NORMAL].push_back(para);
}
}
CHECK(!parametersArray[UPDATER_SPARSE_REMOTE].empty());
CHECK(!parametersArray[UPDATER_NORMAL].empty());
syncThreadPool_->execPlusOwner([&](int tid, size_t numThreads) {
updaters_[tid]->init(parametersArray[tid]);
});
parameterTypes_ = updaters_[UPDATER_NORMAL]->getParameterTypes();
}
std::vector<std::function<ParameterUpdater*(
const std::string&, const OptimizationConfig&, bool, size_t)>>
ParameterUpdaterCreators::constructors_;
} // namespace paddle