// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

option optimize_for = LITE_RUNTIME;

package paddle;

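// This file defines the messages used by the PaddlePaddle optimizer library:
// per-optimizer hyper-parameter configs, learning rate policies, and the
// serialized optimizer state used for checkpointing.
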
message SGDConfig {
  // SGD
  // momentum: float >= 0. Parameter updates momentum.
  // decay: float >= 0. Learning rate decay over each update.
  // nesterov: boolean. Whether to apply Nesterov momentum.
  optional double momentum = 21 [ default = 0.0 ];
  optional double decay = 23 [ default = 0.0 ];
  optional bool nesterov = 24 [ default = false ];
}

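// A sketch of the update SGDConfig parameterizes, assuming the usual
// momentum-SGD formulation (the authoritative version is the C++ optimizer
// implementation):
//   v         = momentum * v - lr * gradient
//   parameter += nesterov ? momentum * v - lr * gradient : v
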
message AdadeltaConfig {
  // Adadelta
  // It is recommended to leave it at the default value.
  // rho: float >= 0.
  // epsilon: float >= 0. Fuzz factor.
  // decay: float >= 0. Learning rate decay over each update.
  //
  // reference: [Adadelta - an adaptive learning rate
  // method](http://arxiv.org/abs/1212.5701)
  optional double rho = 33 [ default = 0.90 ];
  optional double epsilon = 31 [ default = 1e-5 ];
  optional double decay = 32 [ default = 0.0 ];
}

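// For reference, the standard Adadelta update from the paper cited above,
// where E[.] is an exponential moving average with decay rho:
//   E[g^2]  = rho * E[g^2]  + (1 - rho) * g^2
//   dx      = -sqrt((E[dx^2] + epsilon) / (E[g^2] + epsilon)) * g
//   E[dx^2] = rho * E[dx^2] + (1 - rho) * dx^2
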
message AdagradConfig {
  // Adagrad
  // epsilon: float >= 0.
  // decay: float >= 0. Learning rate decay over each update.
  //
  // reference: [Adaptive Subgradient Methods for Online Learning and
  // Stochastic
  // Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
  optional double epsilon = 41 [ default = 1e-5 ];
  optional double decay = 42 [ default = 0.0 ];
}

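// For reference, the standard Adagrad update (the exact epsilon placement
// varies between implementations):
//   accum     += g^2
//   parameter -= lr * g / (sqrt(accum) + epsilon)
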
message AdamConfig {
  // Adam
  // beta_1: float, 0 < beta < 1. Generally close to 1.
  // beta_2: float, 0 < beta < 1. Generally close to 1.
  // epsilon: float >= 0. Fuzz factor.
  // decay: float >= 0. Learning rate decay over each update.
  // reference: [Adam - A Method for Stochastic
  // Optimization](http://arxiv.org/abs/1412.6980v8)
  optional double beta_1 = 41;
  optional double beta_2 = 42;
  optional double epsilon = 43;
  optional double decay = 44;
}

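// For reference, the Adam update from the paper cited above (m_hat and v_hat
// are the bias-corrected first and second moment estimates):
//   m         = beta_1 * m + (1 - beta_1) * g
//   v         = beta_2 * v + (1 - beta_2) * g^2
//   parameter -= lr * m_hat / (sqrt(v_hat) + epsilon)
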
message ConstLrConfig {
  // constant learning rate policy
  optional double learning_rate = 1 [ default = 1.0 ];
}

message LinearLrConfig {
  // linear learning rate policy
  optional double learning_rate = 1 [ default = 1.0 ];
  optional double lr_decay_a = 2;
  optional double lr_decay_b = 3;
}

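// A plausible reading of LinearLrConfig, inferred from the field names (the
// authoritative schedule is the C++ lr-policy implementation): the rate
// decays linearly with the number of samples seen, floored at lr_decay_b:
//   lr = max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b)
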
message TensorProto {
  enum DataType {
    PADDLE_ELEMENT_TYPE_INT32 = 0;
    PADDLE_ELEMENT_TYPE_UINT32 = 1;
    PADDLE_ELEMENT_TYPE_INT64 = 2;
    PADDLE_ELEMENT_TYPE_UINT64 = 3;
    PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
    PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
  }
  optional DataType data_type = 1;
  repeated bytes content = 2;
}

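// Reading of the schema above: `content` carries the raw serialized element
// bytes and `data_type` tells the consumer how to decode them; the exact
// byte layout is defined by the C++ serialization code.
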
message LrPolicyState {
  // learning rate policy state
  optional double learning_rate = 1 [ default = 1.0 ];
  optional double lr_decay_a = 2;
  optional double lr_decay_b = 3;
}

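// The *OptimizerState messages below snapshot an optimizer for checkpointing.
// Judging by the numbering convention used in this file, fields numbered
// >= 100 (lr_state, num_sample_passed) are bookkeeping shared by every
// optimizer, while the low-numbered TensorProto fields hold the
// per-optimizer tensors.
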
message SGDOptimizerState {
  // learning rate policy
  optional LrPolicyState lr_state = 101;
  optional double num_sample_passed = 104;
  // state
  optional TensorProto parameter = 1;
  optional TensorProto momentums = 2;
}

message AdadeltaOptimizerState {
  // learning rate policy
  optional LrPolicyState lr_state = 101;
  optional double num_sample_passed = 104;
  // state
  optional TensorProto parameter = 1;
  optional TensorProto accum_gradient = 2;
  optional TensorProto accum_delta = 3;
  optional TensorProto update_delta = 4;
}

message AdagradOptimizerState {
  // learning rate policy
  optional LrPolicyState lr_state = 101;
  optional double num_sample_passed = 104;
  // state
  optional TensorProto parameter = 1;
  optional TensorProto accum_gradient = 2;
}

message AdamOptimizerState {
  // learning rate policy
  optional LrPolicyState lr_state = 101;
  optional double num_sample_passed = 104;
  // state
  optional TensorProto parameter = 1;
  optional TensorProto momentums = 2;
  optional TensorProto velocitys = 3;
}

message OptimizerConfig {
  enum Optimizer {
    SGD = 1;
    Adadelta = 2;
    Adagrad = 3;
    Adam = 4;
  }
  optional Optimizer optimizer = 1;
  optional SGDConfig sgd = 3;
  optional AdadeltaConfig adadelta = 4;
  optional AdagradConfig adagrad = 5;
  optional AdamConfig adam = 6;

  enum LrPolicy {
    Const = 0;
    Linear = 1;
  }
  optional LrPolicy lr_policy = 11;
  optional ConstLrConfig const_lr = 12;
  optional LinearLrConfig linear_lr = 13;

  // common optimizer config
  // clip gradients when their L2 norm exceeds this value
  optional double clip_norm = 101;
  // clip gradients when their L1 norm exceeds this value
  optional double clip_value = 102;
}
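
// Example OptimizerConfig in proto text format -- an illustrative sketch, not
// taken from a real training job. `optimizer` selects which sub-config is
// read and `lr_policy` selects which lr config is read; the beta/epsilon
// values below are the defaults suggested in the Adam paper:
//   optimizer: Adam
//   adam { beta_1: 0.9 beta_2: 0.999 epsilon: 1e-8 }
//   lr_policy: Const
//   const_lr { learning_rate: 0.001 }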