"polish name convention"

gangliao-patch-1
dongzhihong 8 years ago
parent e1acd73fab
commit 7edabe74d4

@ -33,7 +33,7 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
std::string lr_str = this->lr_policy_->SerializeState(state_len);
LrPolicyState lr_state;
lr_state.ParseFromString(lr_str);
state.mutable_lr_state() = lr_state;
state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
auto str = state.SerializeAsString();
@ -44,6 +44,8 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
void SGDOptimizer::DeserializeState(const std::string &str) {
SGDOptimizerState state;
state.ParseFromString(str);
auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_);
if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);

@ -86,7 +86,7 @@ message LrPolicyState {
}
message SGDOptimizerState {
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
@ -106,7 +106,7 @@ message AdadeltaOptimizerState {
message AdagradOptimizerState {
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;
@ -114,7 +114,7 @@ message AdagradOptimizerState {
}
message AdamOptimizerState {
optional LrPolicyState lrstate = 101;
optional LrPolicyState lr_state = 101;
optional double num_sample_passed = 104;
// state
optional TensorProto parameter = 1;

Loading…
Cancel
Save