@@ -110,43 +110,10 @@ void NewRemoteParameterUpdater::init(
  // overwrite optimizerConfigV2 for per-parameter(layer) configs
  for (int i = 0; i < parameterSize(); ++i) {
    auto paramConfig = parameters_[i]->getConfig();
    if (paramConfig.has_momentum() &&
        trainerConfig_.learning_method() == "momentum") {
      optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
    }
    if (paramConfig.has_learning_rate()) {
      switch (optimizerConfigV2.lr_policy()) {
        case 0:
          optimizerConfigV2.mutable_const_lr()->set_learning_rate(
              paramConfig.learning_rate());
          break;
        case 1:
          optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
              paramConfig.learning_rate());
          break;
      }
    }
    if (paramConfig.has_decay_rate()) {
      switch (optimizerConfigV2.optimizer()) {
        case 1:  // SGD
          optimizerConfigV2.mutable_sgd()->set_decay(
              paramConfig.decay_rate());
          break;
        case 2:  // Adadelta
          optimizerConfigV2.mutable_adadelta()->set_decay(
              paramConfig.decay_rate());
          break;
        case 3:  // Adagrad
          optimizerConfigV2.mutable_adagrad()->set_decay(
              paramConfig.decay_rate());
          break;
        case 4:  // Adam
          optimizerConfigV2.mutable_adam()->set_decay(
              paramConfig.decay_rate());
          break;
      }
    }
    // FIXME(typhoonzero): paramConfig always has default values;
    // how to check if it's default?
    // TODO(typhoonzero): log output: optimizerConfigV2.DebugString();
    LOG(INFO) << "trainerConfig_: " << trainerConfig_.DebugString();
    // send param and config to pserver
    std::string bytes = optimizerConfigV2.SerializeAsString();
    const char *array = bytes.data();
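
For orientation only (not part of the diff): the bytes produced by SerializeAsString() above can be decoded on the receiving side with the standard protobuf inverse call. A minimal sketch, assuming the generated OptimizerConfig.pb.h header and the paddle proto namespace; the helper name is made up for illustration:

    #include <string>

    #include "OptimizerConfig.pb.h"  // assumed name of the generated proto header

    // Hypothetical helper: rebuild the optimizer config from the serialized
    // bytes and read back one of the per-parameter overrides set in the loop above.
    bool parseConfig(const std::string &bytes, double *momentum_out) {
      paddle::OptimizerConfig parsed;  // namespace assumed to match the proto package
      if (!parsed.ParseFromString(bytes)) {  // inverse of SerializeAsString()
        return false;
      }
      *momentum_out = parsed.sgd().momentum();  // same sub-message written via mutable_sgd()
      return true;
    }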