|
|
|
@ -45,11 +45,9 @@ public:
|
|
|
|
config_.mutable_sgd()->set_nesterov(false);
|
|
|
|
config_.mutable_sgd()->set_nesterov(false);
|
|
|
|
config_.set_lr_policy(OptimizerConfig::Const);
|
|
|
|
config_.set_lr_policy(OptimizerConfig::Const);
|
|
|
|
config_.mutable_const_lr()->set_learning_rate(0.1);
|
|
|
|
config_.mutable_const_lr()->set_learning_rate(0.1);
|
|
|
|
|
|
|
|
|
|
|
|
std::string str = config_.SerializeAsString();
|
|
|
|
std::string str = config_.SerializeAsString();
|
|
|
|
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
|
|
|
|
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
|
|
|
|
opts_.push_back(opt);
|
|
|
|
opts_.push_back(opt);
|
|
|
|
opts_table_[opts_.size()] = OptimizerConfig::SGD;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CreateAdam() {
|
|
|
|
void CreateAdam() {
|
|
|
|
@ -64,7 +62,6 @@ public:
|
|
|
|
std::string str = config_.SerializeAsString();
|
|
|
|
std::string str = config_.SerializeAsString();
|
|
|
|
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
|
|
|
|
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
|
|
|
|
opts_.push_back(opt);
|
|
|
|
opts_.push_back(opt);
|
|
|
|
opts_table_[opts_.size()] = OptimizerConfig::Adam;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void TestGetWeight() {
|
|
|
|
void TestGetWeight() {
|
|
|
|
@ -86,21 +83,15 @@ public:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void TestCheckPoint() {
|
|
|
|
void TestCheckPoint() {
|
|
|
|
std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
|
|
|
|
|
|
|
|
{OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
|
|
|
|
|
|
|
|
{OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
for (size_t i = 0; i < opts_.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < opts_.size(); ++i) {
|
|
|
|
int state_len = 0;
|
|
|
|
int state_len = 0;
|
|
|
|
std::string state = opts_[i]->SerializeState(&state_len);
|
|
|
|
std::string state = opts_[i]->SerializeState(&state_len);
|
|
|
|
EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
|
|
|
|
|
|
|
|
opts_[i]->DeserializeState(state);
|
|
|
|
opts_[i]->DeserializeState(state);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
std::vector<ParameterOptimizer*> opts_;
|
|
|
|
std::vector<ParameterOptimizer*> opts_;
|
|
|
|
std::map<int, OptimizerConfig::Optimizer> opts_table_;
|
|
|
|
|
|
|
|
OptimizerConfig config_;
|
|
|
|
OptimizerConfig config_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|