|
|
|
@ -30,7 +30,11 @@ void SGDOptimizer::Update(const Tensor *gradient) {
|
|
|
|
|
const char *SGDOptimizer::SerializeState(int *state_len) {
|
|
|
|
|
SGDOptimizerState state;
|
|
|
|
|
state.set_num_sample_passed(num_sample_passed_);
|
|
|
|
|
state.set_lr_ TensorToProto(*parameter_, state.mutable_parameter());
|
|
|
|
|
std::string lr_str = this->lr_policy_->SerializeState(state_len);
|
|
|
|
|
LrPolicyState lr_state;
|
|
|
|
|
lr_state.ParseFromString(lr_str);
|
|
|
|
|
state.mutable_lr_state() = lr_state;
|
|
|
|
|
TensorToProto(*parameter_, state.mutable_parameter());
|
|
|
|
|
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
|
|
|
|
|
auto str = state.SerializeAsString();
|
|
|
|
|
*state_len += str.size();
|
|
|
|
|