"fix parameter accumulate size"

gangliao-patch-1
dongzhihong 8 years ago
parent 7edabe74d4
commit dec65aca7d

@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
const char* AdadeltaOptimizer::SerializeState(int* state_len) { const char* AdadeltaOptimizer::SerializeState(int* state_len) {
AdadeltaOptimizerState state; AdadeltaOptimizerState state;
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
std::string lr_str = this->lr_policy_->SerializeState(state_len);
state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
TensorToProto(*accum_delta_, state.mutable_accum_delta()); TensorToProto(*accum_delta_, state.mutable_accum_delta());
TensorToProto(*update_delta_, state.mutable_update_delta()); TensorToProto(*update_delta_, state.mutable_update_delta());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
void AdadeltaOptimizer::DeserializeState(const std::string& str) { void AdadeltaOptimizer::DeserializeState(const std::string& str) {
AdadeltaOptimizerState state; AdadeltaOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);

@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
} }
const char* AdagradOptimizer::SerializeState(int* state_len) { const char* AdagradOptimizer::SerializeState(int* state_len) {
AdagradOptimizerState state; AdagradOptimizerState state;
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
std::string lr_str = this->lr_policy_->SerializeState(state_len);
state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
void AdagradOptimizer::DeserializeState(const std::string& str) { void AdagradOptimizer::DeserializeState(const std::string& str) {
AdagradOptimizerState state; AdagradOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.accum_gradient(), accum_gradient_); ProtoToTensor(state.accum_gradient(), accum_gradient_);

@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
const char *AdamOptimizer::SerializeState(int *state_len) { const char *AdamOptimizer::SerializeState(int *state_len) {
AdamOptimizerState state; AdamOptimizerState state;
// TODO(zhihong) : add lr_policy serialization std::string lr_str = this->lr_policy_->SerializeState(state_len);
state.mutable_lr_state()->ParseFromString(lr_str);
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*momentums_, state.mutable_momentums()); TensorToProto(*momentums_, state.mutable_momentums());
TensorToProto(*velocitys_, state.mutable_velocitys()); TensorToProto(*velocitys_, state.mutable_velocitys());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
void AdamOptimizer::DeserializeState(const std::string &str) { void AdamOptimizer::DeserializeState(const std::string &str) {
AdamOptimizerState state; AdamOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);

@ -31,8 +31,6 @@ const char *SGDOptimizer::SerializeState(int *state_len) {
SGDOptimizerState state; SGDOptimizerState state;
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
std::string lr_str = this->lr_policy_->SerializeState(state_len); std::string lr_str = this->lr_policy_->SerializeState(state_len);
LrPolicyState lr_state;
lr_state.ParseFromString(lr_str);
state.mutable_lr_state()->ParseFromString(lr_str); state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums()); if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());

Loading…
Cancel
Save