|
|
|
@ -2,11 +2,8 @@
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "adadelta_optimizer.h"
|
|
|
|
|
#include "adagrad_optimizer.h"
|
|
|
|
|
#include "adam_optimizer.h"
|
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
|
#include "sgd_optimizer.h"
|
|
|
|
|
#include "lr_policy.h"
|
|
|
|
|
|
|
|
|
|
using namespace paddle;
|
|
|
|
|
using namespace paddle::optimizer;
|
|
|
|
@ -41,12 +38,12 @@ public:
|
|
|
|
|
virtual void TearDown() {}
|
|
|
|
|
|
|
|
|
|
void CreateSGD() {
|
|
|
|
|
Tensor* parameter = FillTensor(kSize);
|
|
|
|
|
Tensor* parameter = FixedTensor(kSize);
|
|
|
|
|
config_.set_optimizer(OptimizerConfig::SGD);
|
|
|
|
|
config_.mutable_sgd()->set_momentum(0.0);
|
|
|
|
|
config_.mutable_sgd()->set_decay(0.0);
|
|
|
|
|
config_.mutable_sgd()->set_nesterov(false);
|
|
|
|
|
config_.set_lr_policy(OptimizerConfig::ConstLr);
|
|
|
|
|
config_.set_lr_policy(OptimizerConfig::Const);
|
|
|
|
|
config_.mutable_const_lr()->set_learning_rate(0.1);
|
|
|
|
|
|
|
|
|
|
std::string str = config_.SerializeAsString();
|
|
|
|
@ -62,7 +59,7 @@ public:
|
|
|
|
|
config_.mutable_adam()->set_beta_2(0.1);
|
|
|
|
|
config_.mutable_adam()->set_epsilon(1e-3);
|
|
|
|
|
config_.mutable_adam()->set_decay(0.0);
|
|
|
|
|
config_.set_lr_policy(OptimizerConfig::ConstLr);
|
|
|
|
|
config_.set_lr_policy(OptimizerConfig::Const);
|
|
|
|
|
config_.mutable_const_lr()->set_learning_rate(0.1);
|
|
|
|
|
std::string str = config_.SerializeAsString();
|
|
|
|
|
ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
|
|
|
|
@ -90,12 +87,13 @@ public:
|
|
|
|
|
|
|
|
|
|
void TestCheckPoint() {
|
|
|
|
|
std::map<OptimizerConfig::Optimizer, int> expected_state_len = {
|
|
|
|
|
{OptimizerConfig::SGD, kSize}, {OptimizerConfig::Adam, kSize * 3},
|
|
|
|
|
{OptimizerConfig::SGD, kSize * sizeof(float) + sizeof(double)},
|
|
|
|
|
{OptimizerConfig::Adam, kSize * 3 * sizeof(float) + sizeof(double)},
|
|
|
|
|
};
|
|
|
|
|
for (size_t i = 0; i < opts_.size(); ++i) {
|
|
|
|
|
int state_len = 0;
|
|
|
|
|
std::string state = opts_[i]->SerializeState(&state_len);
|
|
|
|
|
EXPECT_EQ(state_len, expected_state_len[opts_table_[i]]);
|
|
|
|
|
EXPECT_EQ(state_len, expected_state_len[opts_table_[i + 1]]);
|
|
|
|
|
opts_[i]->DeserializeState(state);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|