You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
1.8 KiB
72 lines
1.8 KiB
8 years ago
|
#include "parameter_optimizer.h"
|
||
|
#include <glog/logging.h>
|
||
|
#include "optimizer_factory.h"
|
||
|
|
||
|
namespace paddle {
|
||
|
namespace optimizer {
|
||
|
|
||
|
template <class T>
|
||
|
ParameterOptimizer<T> *ParameterOptimizer<T>::create(
|
||
|
const ::std::string &config_proto) {
|
||
|
paddle::OptimizerConfig config;
|
||
|
CHECK(config.ParseFromString(config_proto) == 0)
|
||
|
<< "error : optimizer config";
|
||
|
CHECK(config_valid(config) == 0) << "error : invalid optimizer config ";
|
||
|
ParameterOptimizer<T> *opt = nullptr;
|
||
|
switch (config.optimizer_name()) {
|
||
|
case "SGD":
|
||
|
opt = new SGDOptimizer<T>(config);
|
||
|
break;
|
||
|
case "Adagrad":
|
||
|
opt = new AdagradOptimizer<T>(config);
|
||
|
break;
|
||
|
case "Adadelta":
|
||
|
opt = new AdadeltaOptimizer<T>(config);
|
||
|
break;
|
||
|
case "Adam":
|
||
|
opt = new AdamOptimizer<T>(config);
|
||
|
break;
|
||
|
default:
|
||
|
opt = new SGDOptimizer<T>(config);
|
||
|
}
|
||
|
|
||
|
switch (config.lr_policy()) {
|
||
|
case "ConstLr":
|
||
|
opt.lr_policy = new ConstLr(config);
|
||
|
break;
|
||
|
}
|
||
|
return opt;
|
||
|
}
|
||
|
|
||
|
template <class T>
|
||
|
T *ParameterOptimizer<T>::get_weight() const {
|
||
|
return parameter.get().get_buffer();
|
||
|
}
|
||
|
|
||
|
template <class T>
|
||
|
char *ParameterOptimizer<T>::get_config_proto() const {
|
||
|
// set config dynamic value for save checkpoint
|
||
|
config_.lr_policy().set_learning_rate(
|
||
|
lr_policy->get_learning_rate(num_sample_passed));
|
||
|
config_.set_num_sample_passed(num_sample_passed);
|
||
|
config_.set_iterations(iterations);
|
||
|
return config_.SerializeAsString().c_str();
|
||
|
}
|
||
|
|
||
|
template <class T>
|
||
|
void ParameterOptimizer<T>::set_weight(const Tensor<T> *p) {
|
||
|
parameter_ = p;
|
||
|
}
|
||
|
|
||
|
template <class T>
|
||
|
bool ParameterOptimizer<T>::config_valid(const ::std::string &config) const {
|
||
|
// TODO(zhihong) : add more value checker, failed ASAP
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
template class ParameterOptimzier<float>;
|
||
|
template class ParameterOptimzier<double>;
|
||
|
|
||
|
} // namespace optimizer
|
||
|
} // namespace paddle
|