@@ -22,7 +22,8 @@ DECLARE_string(save_dir);
 namespace paddle {
 NewRemoteParameterUpdater::NewRemoteParameterUpdater(
     const OptimizationConfig &config, const std::string pserverSpec)
-    : parameterClient_(-1),
+    : trainerConfig_(config),
+      parameterClient_(-1),
       newParameters_(nullptr),
       newGradients_(nullptr),
       pserverSpec_(pserverSpec) {}
@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init(
     LOG(INFO) << "paddle_begin_init_params start";
     for (int i = 0; i < parameterSize(); ++i) {
       auto paramConfig = parameters_[i]->getConfig();
-      std::string bytes = paramConfig.SerializeAsString();
+      LOG(INFO) << "old param config: " << paramConfig.DebugString();
+      // FIXME(typhoonzero): convert old paramConfig to optimizerConfig
+      OptimizerConfig optimizeConfigV2;
+      auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
+      sgdConfigV2->set_momentum(paramConfig.momentum());
+      sgdConfigV2->set_decay(paramConfig.decay_rate());
+      optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
+      auto constlr = optimizeConfigV2.mutable_const_lr();
+      constlr->set_learning_rate(paramConfig.learning_rate());
+      if (trainerConfig_.algorithm() == "sgd") {
+        optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+        // FIXME: config all algorithms
+      } else {
+        optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
+      }
+      std::string bytes = optimizeConfigV2.SerializeAsString();
       const char *array = bytes.data();
       int size = (int)bytes.size();
       paddle_init_param(
@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
 void NewRemoteParameterUpdater::startPass() {}
 
 bool NewRemoteParameterUpdater::finishPass() { return true; }
-}
+}  // namespace paddle
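
For reference, the conversion in the second hunk only exercises a small part of the new OptimizerConfig message. The sketch below is inferred purely from the generated accessors used in that hunk (mutable_sgd(), set_momentum(), set_decay(), set_lr_policy(), mutable_const_lr(), set_learning_rate(), set_optimizer()); the nested type names and field numbers are illustrative assumptions, not the actual proto/OptimizerConfig.proto definition.

// Sketch only: message and field names inferred from the accessors used above.
syntax = "proto2";
package paddle;

message SGDConfig {
  optional double momentum = 1;         // sgdConfigV2->set_momentum(...)
  optional double decay = 2;            // sgdConfigV2->set_decay(...)
}

message ConstLrConfig {
  optional double learning_rate = 1;    // constlr->set_learning_rate(...)
}

message OptimizerConfig {
  enum Optimizer { SGD = 1; }           // only SGD is wired up in this change
  enum LrPolicy { Const = 0; }
  optional Optimizer optimizer = 1;     // optimizeConfigV2.set_optimizer(...)
  optional SGDConfig sgd = 2;           // optimizeConfigV2.mutable_sgd()
  optional LrPolicy lr_policy = 3;      // optimizeConfigV2.set_lr_policy(...)
  optional ConstLrConfig const_lr = 4;  // optimizeConfigV2.mutable_const_lr()
}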