@@ -28,10 +28,12 @@ from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
 from ..framework import default_main_program, Parameter, unique_name, name_scope
+import math

 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
-    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'append_LARS'
+    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'append_LARS',
+    'cosine_decay'
 ]
@@ -307,6 +309,39 @@ def piecewise_decay(boundaries, values):
         return lr


+def cosine_decay(learning_rate, step_each_epoch, epochs):
+    """
+    Applies cosine decay to the learning rate.
+
+    When training a model, it is often recommended to lower the learning rate as the
+    training progresses. By using this function, the learning rate will be decayed by
+    following the cosine decay strategy.
+
+    Args:
+        learning_rate(Variable|float): The initial learning rate.
+        step_each_epoch(int): the number of steps in an epoch.
+        epochs(int): the number of epochs.
+
+    Returns:
+        Variable: The decayed learning rate.
+
+    Examples:
+
+        .. code-block:: python
+
+            base_lr = 0.1
+            lr = fluid.layers.cosine_decay(
+                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
+    """
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
+
+        cur_epoch = ops.floor(global_step / step_each_epoch)
+        decayed_lr = learning_rate * 0.5 * (
+            ops.cos(cur_epoch * math.pi / epochs) + 1)
+        return decayed_lr
+
+
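As a sanity check on the schedule this patch adds, the epoch-level factor is 0.5 * (cos(pi * cur_epoch / epochs) + 1), which starts at 1 and decays towards 0 over `epochs`. A minimal plain-Python sketch of the same curve (the helper name `cosine_lr` is illustrative only, not part of this patch):

import math

def cosine_lr(base_lr, epoch, epochs):
    # Same epoch-level factor as decayed_lr above, evaluated eagerly.
    return base_lr * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)

# With base_lr = 0.1 and epochs = 120:
#   cosine_lr(0.1, 0, 120)   -> 0.1
#   cosine_lr(0.1, 60, 120)  -> 0.05
#   cosine_lr(0.1, 119, 120) -> ~1.7e-05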
 def append_LARS(params_grads, learning_rate, weight_decay):
     """
     Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for