@@ -66,28 +66,92 @@ void NewRemoteParameterUpdater::init(
   // from parameter server
   if (paddle_begin_init_params(parameterClient_)) {
+    LOG(INFO) << "paddle_begin_init_params start";
+    // NOTE: convert V1 OptimizationConfig proto to V2 OptimizerConfig.
+    // This makes the golang pserver compatible with existing V1 demos.
+    // TODO: Refine or remove these ugly conversion lines.
+    OptimizerConfig optimizerConfigV2;
+    if (trainerConfig_.learning_method() == "momentum") {
+      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+    } else if (trainerConfig_.learning_method() == "adagrad") {
+      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adagrad);
+      optimizerConfigV2.mutable_adagrad()->set_epsilon(
+          trainerConfig_.ada_epsilon());
+    } else if (trainerConfig_.learning_method() == "adadelta") {
+      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adadelta);
+      optimizerConfigV2.mutable_adadelta()->set_epsilon(
+          trainerConfig_.ada_epsilon());
+      optimizerConfigV2.mutable_adadelta()->set_rho(trainerConfig_.ada_rou());
+    } else if (trainerConfig_.learning_method() == "adam") {
+      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adam);
+      optimizerConfigV2.mutable_adam()->set_beta_1(trainerConfig_.adam_beta1());
+      optimizerConfigV2.mutable_adam()->set_beta_2(trainerConfig_.adam_beta2());
+      optimizerConfigV2.mutable_adam()->set_epsilon(
+          trainerConfig_.adam_epsilon());
+    } else {
+      LOG(ERROR) << "got unsupported v1 optimizer config: "
+                 << trainerConfig_.learning_method();
+      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+    }
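+    // Map the V1 learning_rate_schedule onto a V2 lr_policy; only "constant"
+    // and "linear" are handled, anything else falls back to a constant rate.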
+    if (trainerConfig_.learning_rate_schedule() == "constant") {
+      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
+      optimizerConfigV2.mutable_const_lr()->set_learning_rate(
+          trainerConfig_.learning_rate());
+    } else if (trainerConfig_.learning_rate_schedule() == "linear") {
+      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Linear);
+      optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
+          trainerConfig_.learning_rate());
+      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_a(
+          trainerConfig_.learning_rate_decay_a());
+      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_b(
+          trainerConfig_.learning_rate_decay_b());
+    } else {
+      LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
+                 << trainerConfig_.learning_rate_schedule() << ", set to const";
+      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
+    }
+    // overwrite optimizerConfigV2 for per-parameter (layer) configs
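+    // optimizerConfigV2 is one shared message: each loop iteration below
+    // adjusts it for a single parameter and re-serializes it for the pserver.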
     for (int i = 0; i < parameterSize(); ++i) {
       auto paramConfig = parameters_[i]->getConfig();
LOG(INFO) << "old param config: " << paramConfig.DebugString();
|
|
|
|
|
// FIXME(typhoonzero): convert old paramConfig to optimizerConfig
|
|
|
|
|
OptimizerConfig optimizeConfigV2;
|
|
|
|
|
auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
|
|
|
|
|
sgdConfigV2->set_momentum(paramConfig.momentum());
|
|
|
|
|
sgdConfigV2->set_decay(paramConfig.decay_rate());
|
|
|
|
|
optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
|
|
|
|
|
auto constlr = optimizeConfigV2.mutable_const_lr();
|
|
|
|
|
+      if (paramConfig.has_momentum() &&
+          trainerConfig_.learning_method() == "momentum") {
+        optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
+      }
       if (paramConfig.has_learning_rate()) {
-        constlr->set_learning_rate(paramConfig.learning_rate());
-      } else {
-        constlr->set_learning_rate(trainerConfig_.learning_rate());
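+        // A per-parameter learning_rate overrides the global rate of the
+        // lr_policy chosen above.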
+        switch (optimizerConfigV2.lr_policy()) {
+          case 0:  // Const
+            optimizerConfigV2.mutable_const_lr()->set_learning_rate(
+                paramConfig.learning_rate());
+            break;
+          case 1:  // Linear
+            optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
+                paramConfig.learning_rate());
+            break;
+        }
       }
-      if (trainerConfig_.algorithm() == "sgd") {
-        optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
-        // FIXME: config all algorithms
-      } else {
-        optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
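+      // Route the per-parameter decay_rate to whichever optimizer sub-message
+      // was selected above, keyed on the optimizer() enum value.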
+      if (paramConfig.has_decay_rate()) {
+        switch (optimizerConfigV2.optimizer()) {
+          case 1:  // SGD
+            optimizerConfigV2.mutable_sgd()->set_decay(
+                paramConfig.decay_rate());
+            break;
+          case 2:  // Adadelta
+            optimizerConfigV2.mutable_adadelta()->set_decay(
+                paramConfig.decay_rate());
+            break;
+          case 3:  // Adagrad
+            optimizerConfigV2.mutable_adagrad()->set_decay(
+                paramConfig.decay_rate());
+            break;
+          case 4:  // Adam
+            optimizerConfigV2.mutable_adam()->set_decay(
+                paramConfig.decay_rate());
+            break;
+        }
       }
-      std::string bytes = optimizeConfigV2.SerializeAsString();
       // send param and config to pserver
+      std::string bytes = optimizerConfigV2.SerializeAsString();
       const char *array = bytes.data();
       int size = (int)bytes.size();
       paddle_init_param(