You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/paddle/optimizer/parameter_optimizer.cc

72 lines
1.8 KiB

#include "parameter_optimizer.h"
#include <glog/logging.h>
#include "optimizer_factory.h"
namespace paddle {
namespace optimizer {
template <class T>
ParameterOptimizer<T> *ParameterOptimizer<T>::create(
    const ::std::string &config_proto) {
  // Factory: builds a concrete optimizer from a serialized
  // paddle::OptimizerConfig.
  //
  // @param config_proto  serialized OptimizerConfig protobuf bytes.
  // @return a heap-allocated optimizer; the caller takes ownership.
  // Aborts (glog CHECK) if the bytes do not parse or the config is rejected
  // by config_valid().
  paddle::OptimizerConfig config;
  // ParseFromString() returns true on success. The previous `== 0` test
  // aborted on every well-formed config and accepted malformed ones.
  CHECK(config.ParseFromString(config_proto)) << "error : optimizer config";

  // C++ cannot `switch` on strings, so dispatch with an if/else chain.
  // NOTE(review): assumes optimizer_name() is a string field; confirm
  // against the OptimizerConfig schema.
  const auto &name = config.optimizer_name();
  ParameterOptimizer<T> *opt = nullptr;
  if (name == "SGD") {
    opt = new SGDOptimizer<T>(config);
  } else if (name == "Adagrad") {
    opt = new AdagradOptimizer<T>(config);
  } else if (name == "Adadelta") {
    opt = new AdadeltaOptimizer<T>(config);
  } else if (name == "Adam") {
    opt = new AdamOptimizer<T>(config);
  } else {
    // Unknown names keep the original fallback to SGD, but no longer silently.
    LOG(WARNING) << "unknown optimizer \"" << name
                 << "\", falling back to SGD";
    opt = new SGDOptimizer<T>(config);
  }

  // config_valid() is a non-static member taking the serialized string.
  // Calling it through the new object compiles whether or not create() is
  // static, and passes the argument type it actually declares. true means
  // valid; the previous `== 0` test was inverted.
  CHECK(opt->config_valid(config_proto))
      << "error : invalid optimizer config ";

  // ConstLr is the only learning-rate policy implemented. It is also used as
  // the fallback, because leaving lr_policy unset would break
  // get_config_proto(), which dereferences it.
  // NOTE(review): assumes lr_policy() is a string field; confirm.
  if (config.lr_policy() != "ConstLr") {
    LOG(WARNING) << "unknown lr_policy \"" << config.lr_policy()
                 << "\", falling back to ConstLr";
  }
  // `opt` is a pointer, so use `->`. The original `opt.lr_policy` was
  // ill-formed.
  opt->lr_policy = new ConstLr(config);
  return opt;
}
template <class T>
T *ParameterOptimizer<T>::get_weight() const {
  // Returns the element buffer of the tensor most recently installed by
  // set_weight().
  //
  // set_weight() stores into `parameter_`, a Tensor<T> pointer. The previous
  // `parameter.get().get_buffer()` named a different member and applied `.`
  // to a pointer.
  // NOTE(review): set_weight() must have been called first; if parameter_
  // is still unset, this dereference is invalid.
  return parameter_->get_buffer();
}
template <class T>
char *ParameterOptimizer<T>::get_config_proto() const {
  // Serializes the optimizer config, including current training progress,
  // for checkpointing.
  //
  // @return pointer to the serialized bytes. They stay valid until the next
  //         call to get_config_proto() on the same thread, for any
  //         ParameterOptimizer<T> with the same T, because the buffer is a
  //         per-thread static.
  // NOTE(review): protobuf binary output may contain NUL bytes. Callers that
  // use strlen() on the result will truncate it, and the interface carries
  // no length.

  // Copy the dynamic training state into the config before serializing.
  // NOTE(review): this mutates config_ in a const method, which only
  // compiles if config_ is declared `mutable`; confirm in the header.
  // NOTE(review): protobuf sub-message setters usually go through
  // mutable_lr_policy(); confirm against the OptimizerConfig schema.
  if (lr_policy != nullptr) {
    config_.lr_policy().set_learning_rate(
        lr_policy->get_learning_rate(num_sample_passed));
  }
  config_.set_num_sample_passed(num_sample_passed);
  config_.set_iterations(iterations);

  // The previous `SerializeAsString().c_str()` pointed into a temporary
  // destroyed at the end of the return statement, so the caller always
  // received a dangling pointer. That `const char*` also does not convert to
  // `char*`. Keep the bytes in storage that outlives this call.
  static thread_local ::std::string serialized;
  serialized = config_.SerializeAsString();
  return &serialized[0];
}
template <typename T>
void ParameterOptimizer<T>::set_weight(const Tensor<T> *p) {
  // Installs the tensor this optimizer reads and updates.
  //
  // Only the pointer is stored; the tensor is not copied. The pointee must
  // therefore outlive its use here.
  // NOTE(review): presumably the caller retains ownership; confirm against
  // the header and destructor.
  this->parameter_ = p;
}
template <typename T>
bool ParameterOptimizer<T>::config_valid(const ::std::string &config) const {
  // Placeholder validator for a serialized optimizer config.
  // The argument is not inspected yet, so every config is reported as valid.
  // TODO(zhihong): add value checks so a bad config fails as early as
  // possible.
  static_cast<void>(config);  // intentionally unused for now
  return true;
}
template class ParameterOptimzier<float>;
template class ParameterOptimzier<double>;
} // namespace optimizer
} // namespace paddle