@@ -664,6 +664,7 @@ class LinearLrWarmup(LearningRateDecay):
                 format(learning_rate))
         self.learning_rate = learning_rate
         self.warmup_steps = warmup_steps
+        self.start_lr = start_lr
         assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
             end_lr, start_lr)
         self.lr_ratio_before_warmup = (
@@ -676,7 +677,7 @@ class LinearLrWarmup(LearningRateDecay):
 
         from .. import layers
         if self.step_num < self.warmup_steps:
-            return self.lr_ratio_before_warmup * self.step_num
+            return self.lr_ratio_before_warmup * self.step_num + self.start_lr
         else:
             return base_lr
 
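
The first hunk stores start_lr on the instance; the second hunk is the functional fix: during warmup the returned rate is now offset by start_lr, so the schedule ramps linearly from start_lr to end_lr instead of from 0 to (end_lr - start_lr). A minimal standalone sketch of the before/after behavior, using toy values that are illustrative and not taken from the patch:

    # Linear warmup as computed in LinearLrWarmup.step(), with toy values.
    start_lr, end_lr, warmup_steps = 0.1, 1.0, 4
    lr_ratio_before_warmup = (float(end_lr) - float(start_lr)) / float(warmup_steps)

    for step_num in range(warmup_steps + 1):
        old_lr = lr_ratio_before_warmup * step_num              # before the fix: 0.0 at step 0, 0.9 at step 4
        new_lr = lr_ratio_before_warmup * step_num + start_lr   # after the fix: 0.1 at step 0, 1.0 at step 4
        print(step_num, old_lr, new_lr)

With the fix, step warmup_steps yields exactly end_lr; past that point the warmup branch no longer applies and step() falls through to return base_lr.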