From eb932f717af2e4260d82c19f182d2a7d91f6b127 Mon Sep 17 00:00:00 2001
From: shippingwang
Date: Fri, 22 Feb 2019 13:07:05 +0000
Subject: [PATCH 1/4] add cosine decay op, test=develop

---
 paddle/fluid/API.spec                           |  1 +
 .../fluid/layers/learning_rate_scheduler.py     | 37 ++++++++++++++++++-
 .../unittests/test_learning_rate_scheduler.py   | 12 ++++++
 3 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index df961be911..e0c8ad09c4 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -334,6 +334,7 @@ paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_step
 paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
 paddle.fluid.layers.polynomial_decay ArgSpec(args=['learning_rate', 'decay_steps', 'end_learning_rate', 'power', 'cycle'], varargs=None, keywords=None, defaults=(0.0001, 1.0, False))
 paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.layers.cosine_decay ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.contrib.InitState.__init__ ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32'))
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index 617704a531..4c1996331c 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -28,10 +28,12 @@ from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
 from ..framework import default_main_program, Parameter, unique_name, name_scope
+import math
 
 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
-    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'append_LARS'
+    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'append_LARS',
+    'cosine_decay'
 ]
 
 
@@ -307,6 +309,39 @@ def piecewise_decay(boundaries, values):
     return lr
 
 
+def cosine_decay(learning_rate, step_each_epoch, epochs):
+    """
+    Applies cosine decay to the learning rate.
+
+    When training a model, it is oftem recommended to lower the learning rate as the
+    training progresses. By using this function, the learning rate will be decayed
+    following the cosine decay strategy.
+
+    Args:
+        learning_rate(Variable|float): The initial learning rate.
+        step_each_epoch(int): The number of steps in an epoch.
+        epochs(int): The number of epochs.
+
+    Returns:
+        Variable: The decayed learning rate.
+
+    Examples:
+
+        .. code-block:: python
+
+            base_lr = 0.1
+            lr = fluid.layers.cosine_decay(
+                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
+    """
+    with default_main_program()._lr_schedule_guard():
+        global_step = _decay_step_counter()
+
+        cur_epoch = ops.floor(global_step / step_each_epoch)
+        decayed_lr = learning_rate * 0.5 * (
+            ops.cos(cur_epoch * math.pi / epochs) + 1)
+        return decayed_lr
+
+
 def append_LARS(params_grads, learning_rate, weight_decay):
     """
     Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for
diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index 0d3e6d73e0..5212d97dfb 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -82,6 +82,13 @@ def piecewise_decay(global_step, boundaries, values):
     return values[len(values) - 1]
 
 
+def cosine_decay(global_step, learning_rate, step_each_epoch, epochs):
+    cur_epoch = math.floor(global_step / step_each_epoch)
+    decayed_lr = learning_rate * 0.5 * (
+        math.cos(cur_epoch * math.pi / epochs) + 1)
+    return decayed_lr
+
+
 class TestLearningRateDecay(unittest.TestCase):
     def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
         places = [fluid.CPUPlace()]
@@ -149,6 +156,11 @@ class TestLearningRateDecay(unittest.TestCase):
             "boundaries": [3, 6, 9],
             "values": [0.1, 0.2, 0.3, 0.4]
         }),
+        (cosine_decay, layers.cosine_decay, {
+            "learning_rate": 0.1,
+            "step_each_epoch": 100,
+            "epochs": 120
+        }),
     ]
 
     for py_decay_fn, fluid_decay_fn, kwargs in decay_fns:

From 5ce46c637a7148b6536679ffe75291386c3aa59d Mon Sep 17 00:00:00 2001
From: shippingwang
Date: Tue, 26 Feb 2019 08:39:30 +0000
Subject: [PATCH 2/4] fix api.spec, test=develop

---
 paddle/fluid/API.spec | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index e0c8ad09c4..0107161b4c 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -337,6 +337,7 @@ paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varar
 paddle.fluid.layers.cosine_decay ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.layers.cosine_decay ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.contrib.InitState.__init__ ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32'))
 paddle.fluid.contrib.StateCell.__init__ ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.contrib.StateCell.compute_state ArgSpec(args=['self', 'inputs'], varargs=None, keywords=None, defaults=None)

From 339829327235e7a4a3ff5bd5f81ec81fabbec11f Mon Sep 17 00:00:00 2001
From: shippingwang
Date: Tue, 26 Feb 2019 08:51:50 +0000
Subject: [PATCH 3/4] add API.spec.

test=develop
---
 paddle/fluid/API.spec | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 0107161b4c..37400f39c7 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -334,7 +334,6 @@ paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_step
 paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
 paddle.fluid.layers.polynomial_decay ArgSpec(args=['learning_rate', 'decay_steps', 'end_learning_rate', 'power', 'cycle'], varargs=None, keywords=None, defaults=(0.0001, 1.0, False))
 paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.cosine_decay ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.cosine_decay ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None)

From 733da7b2fc2322a3df8a844fa039b66e9b2b35dd Mon Sep 17 00:00:00 2001
From: shippingwang
Date: Wed, 27 Feb 2019 05:45:55 +0000
Subject: [PATCH 4/4] fixed typo, test=develop

---
 python/paddle/fluid/layers/learning_rate_scheduler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index 4c1996331c..378aeb3760 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -313,9 +313,11 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
     """
     Applies cosine decay to the learning rate.
 
-    When training a model, it is oftem recommended to lower the learning rate as the
+    When training a model, it is often recommended to lower the learning rate as the
     training progresses. By using this function, the learning rate will be decayed
     following the cosine decay strategy.
+
+    decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)
 
     Args:
         learning_rate(Variable|float): The initial learning rate.
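
Note: the schedule these patches introduce can be sanity-checked outside Paddle with plain Python. A minimal sketch follows; the helper name cosine_decay_ref and the spot-checked step values are illustrative only, not part of the patches. It mirrors the Python reference function added to the unit test above: the current epoch is floor(global_step / step_each_epoch), and the rate traces half a cosine period from the base rate down to zero over `epochs` epochs.

    import math

    def cosine_decay_ref(global_step, learning_rate, step_each_epoch, epochs):
        # Hypothetical standalone mirror of the patched formula:
        # epoch index = floor(global_step / step_each_epoch),
        # lr = learning_rate * 0.5 * (cos(epoch * pi / epochs) + 1)
        cur_epoch = math.floor(global_step / step_each_epoch)
        return learning_rate * 0.5 * (math.cos(cur_epoch * math.pi / epochs) + 1)

    # Spot-check with base_lr=0.1, 100 steps per epoch, 120 epochs:
    print(cosine_decay_ref(0, 0.1, 100, 120))      # 0.1    (epoch 0,   cos(0)    = 1)
    print(cosine_decay_ref(6000, 0.1, 100, 120))   # ~0.05  (epoch 60,  cos(pi/2) = 0)
    print(cosine_decay_ref(12000, 0.1, 100, 120))  # ~0.0   (epoch 120, cos(pi)   = -1)

Because the epoch index is floored, the rate is constant within an epoch and steps down once per epoch, rather than decaying smoothly per step.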