|
|
@ -537,6 +537,18 @@ class TestLRScheduler(unittest.TestCase):
|
|
|
|
self._test_dygraph(python_func, paddle_api, kwarg, place)
|
|
|
|
self._test_dygraph(python_func, paddle_api, kwarg, place)
|
|
|
|
paddle.enable_static()
|
|
|
|
paddle.enable_static()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_linear_warmp(self):
|
|
|
|
|
|
|
|
natural_lr = paddle.optimizer.lr.NaturalExpDecay(
|
|
|
|
|
|
|
|
learning_rate=0.5, gamma=0.1)
|
|
|
|
|
|
|
|
natural_lr_warmup = paddle.optimizer.lr.LinearWarmup(
|
|
|
|
|
|
|
|
learning_rate=natural_lr, warmup_steps=10, start_lr=0.0, end_lr=0.1)
|
|
|
|
|
|
|
|
for idx in range(30):
|
|
|
|
|
|
|
|
if idx >= 10:
|
|
|
|
|
|
|
|
self.assertEqual(natural_lr_warmup.get_lr(),
|
|
|
|
|
|
|
|
natural_lr.get_lr())
|
|
|
|
|
|
|
|
natural_lr.step()
|
|
|
|
|
|
|
|
natural_lr_warmup.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
unittest.main()
|
|
|
|
unittest.main()
|
|
|
|