@@ -284,11 +284,19 @@ def linear_warmup_lr(epoch_num,
                      start_lr,
                      end_lr,
                      verbose=False):
-    if epoch_num < warmup_steps:
+    tmp = epoch_num - warmup_steps
+    if tmp < 0:
         return start_lr + (end_lr - start_lr) * (float(epoch_num) /
                                                  float(warmup_steps))
+    elif paddle.in_dynamic_mode():
+        if tmp < 3:
+            return 0.5
+        elif tmp < 6:
+            return 0.2
+        else:
+            return 0.1
     else:
-        return learning_rate
+        return 0.5
 
 
 def multi_step_lr(epoch_num,
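For context, the rewritten reference now models both phases of the schedule: linear warmup up to `warmup_steps`, then the piecewise values 0.5/0.2/0.1 that the test injects below. A minimal standalone sketch of the warmup branch, assuming the kwargs this diff uses later (`warmup_steps=10`, `start_lr=0`, `end_lr=0.5`); `expected_warmup_lr` is a hypothetical helper, not part of the test file:

```python
# Standalone sketch of the warmup branch of linear_warmup_lr, assuming the
# kwargs this diff uses elsewhere: warmup_steps=10, start_lr=0, end_lr=0.5.
def expected_warmup_lr(epoch_num, warmup_steps=10, start_lr=0.0, end_lr=0.5):
    # Linear interpolation from start_lr toward end_lr while
    # epoch_num < warmup_steps.
    return start_lr + (end_lr - start_lr) * (float(epoch_num) /
                                             float(warmup_steps))

assert expected_warmup_lr(0) == 0.0                # warmup starts at start_lr
assert expected_warmup_lr(5) == 0.25               # halfway through warmup
assert abs(expected_warmup_lr(9) - 0.45) < 1e-12   # last warmup epoch
```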
@@ -407,6 +415,9 @@ class TestLRScheduler(unittest.TestCase):
         paddle.disable_static(place)
         x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
         linear = paddle.nn.Linear(10, 10)
+        if paddle_api.__name__ == "LinearWarmup":
+            kwarg['learning_rate'] = paddle.optimizer.lr.PiecewiseDecay(
+                [3, 6], [0.5, 0.2, 0.1])
         scheduler = paddle_api(**kwarg)
         adam = paddle.optimizer.Adam(
             learning_rate=scheduler, parameters=linear.parameters())
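The injected `PiecewiseDecay([3, 6], [0.5, 0.2, 0.1])` is what makes the `tmp < 3` / `tmp < 6` branches of the reference function line up with the real scheduler. A quick sketch of that correspondence, assuming the Paddle 2.x `LRScheduler` attributes (`last_lr`, `step()`) that this diff itself relies on:

```python
import paddle

# PiecewiseDecay([3, 6], [0.5, 0.2, 0.1]) keeps the lr at 0.5 for epoch < 3,
# 0.2 for 3 <= epoch < 6, and 0.1 from epoch 6 on, which is exactly what the
# reference linear_warmup_lr returns on tmp = epoch_num - warmup_steps.
decay = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6],
                                           values=[0.5, 0.2, 0.1])
seen = []
for _ in range(8):
    seen.append(decay.last_lr)  # lr for the current epoch
    decay.step()                # advance to the next epoch
assert seen == [0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.1, 0.1]
```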
@@ -420,12 +431,26 @@ class TestLRScheduler(unittest.TestCase):
             adam.clear_grad()
             current_lr = adam.get_lr()
             expected_lr = python_func(epoch, **kwarg)
-            if paddle_api.__name__ != "CosineAnnealingDecay":
-                self.assertEqual(current_lr, expected_lr)
-                scheduler.step()
-            else:
+            if paddle_api.__name__ == "CosineAnnealingDecay":
                 self.assertAlmostEqual(current_lr, expected_lr)
                 scheduler.step(epoch + 1)
+            elif paddle_api.__name__ == "LinearWarmup":
+                self.assertAlmostEqual(current_lr, expected_lr)
+                state_dict = adam.state_dict()
+                scheduler1 = paddle.optimizer.lr.LinearWarmup(**kwarg)
+                adam1 = paddle.optimizer.Adam(
+                    learning_rate=scheduler1, parameters=linear.parameters())
+                adam1.set_state_dict(state_dict)
+                self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
+                self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
+                self.assertEqual(scheduler.learning_rate.last_lr,
+                                 scheduler1.learning_rate.last_lr)
+                self.assertEqual(scheduler.learning_rate.last_epoch,
+                                 scheduler1.learning_rate.last_epoch)
+                scheduler.step()
+            else:
+                self.assertEqual(current_lr, expected_lr)
+                scheduler.step()
 
     def test_scheduler(self):
         with self.assertRaises(NotImplementedError):
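The new `LinearWarmup` branch packs a state-dict round trip into the test loop. Here is a minimal standalone version of what it checks, assuming Paddle 2.x in dynamic mode and the same kwargs the test uses; `make_scheduler` is a local helper invented for this sketch:

```python
import paddle

paddle.disable_static()  # the round trip below runs in dygraph mode

def make_scheduler():
    # Same construction the test performs: LinearWarmup wrapping PiecewiseDecay.
    inner = paddle.optimizer.lr.PiecewiseDecay([3, 6], [0.5, 0.2, 0.1])
    return paddle.optimizer.lr.LinearWarmup(
        learning_rate=inner, warmup_steps=10, start_lr=0, end_lr=0.5)

linear = paddle.nn.Linear(10, 10)
scheduler = make_scheduler()
adam = paddle.optimizer.Adam(learning_rate=scheduler,
                             parameters=linear.parameters())
for _ in range(5):
    scheduler.step()  # advance partway through warmup

# Restoring the optimizer's state dict should restore both the outer warmup
# position and the wrapped scheduler's position, which is what the new
# assertions verify.
scheduler1 = make_scheduler()
adam1 = paddle.optimizer.Adam(learning_rate=scheduler1,
                              parameters=linear.parameters())
adam1.set_state_dict(adam.state_dict())
assert scheduler1.last_epoch == scheduler.last_epoch
assert scheduler1.last_lr == scheduler.last_lr
assert scheduler1.learning_rate.last_lr == scheduler.learning_rate.last_lr
```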
@@ -464,8 +489,7 @@ class TestLRScheduler(unittest.TestCase):
             "decay_steps": 20,
             "end_lr": 0,
             "power": 1.0,
-            "cycle": False,
-            "verbose": True
+            "cycle": False
         }), (polynomial_lr, paddle.optimizer.lr.PolynomialDecay, {
             "learning_rate": 0.5,
             "decay_steps": 20,
@@ -475,10 +499,9 @@ class TestLRScheduler(unittest.TestCase):
             "verbose": False
         }), (linear_warmup_lr, paddle.optimizer.lr.LinearWarmup, {
             'learning_rate': 0.5,
-            'warmup_steps': 20,
+            'warmup_steps': 10,
             'start_lr': 0,
-            'end_lr': 0.5,
-            "verbose": True
+            'end_lr': 0.5
         }), (exponential_lr, paddle.optimizer.lr.ExponentialDecay, {
             "learning_rate": 0.5,
             "gamma": 0.9,
@@ -486,8 +509,7 @@ class TestLRScheduler(unittest.TestCase):
         }), (multi_step_lr, paddle.optimizer.lr.MultiStepDecay, {
             "learning_rate": 0.5,
             "milestones": [3, 6, 9, 15, 20],
-            "gamma": 0.8,
-            "verbose": True
+            "gamma": 0.8
         }), (step_lr, paddle.optimizer.lr.StepDecay, {
             "learning_rate": 0.5,
             "step_size": 2,
@@ -510,7 +532,7 @@ class TestLRScheduler(unittest.TestCase):
 
         for place in places:
             paddle.enable_static()
-            #self._test_static(python_func, paddle_api, kwarg, place)
+            self._test_static(python_func, paddle_api, kwarg, place)
             paddle.disable_static(place)
             self._test_dygraph(python_func, paddle_api, kwarg, place)
             paddle.enable_static()