|
|
|
@ -44,7 +44,7 @@ void SGDOptimizer::DeserializeState(const std::string &str) {
|
|
|
|
|
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
|
|
|
|
|
num_sample_passed_ = state.num_sample_passed();
|
|
|
|
|
ProtoToTensor(state.parameter(), parameter_);
|
|
|
|
|
if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
|
|
|
|
|
if (momentum_ != 0.0) ProtoToTensor(state.momentums(), momentums_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace optimizer
|
|
|
|
|