fix LinearLrWarmup bug; test=develop (#24913)

fix_copy_if_different
hong authored 5 years ago; committed by GitHub
parent f6f7df9cd5
commit dbc3fd5eb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -664,6 +664,7 @@ class LinearLrWarmup(LearningRateDecay):
format(learning_rate))
self.learning_rate = learning_rate
self.warmup_steps = warmup_steps
self.start_lr = start_lr
assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
end_lr, start_lr)
self.lr_ratio_before_warmup = (
@ -676,7 +677,7 @@ class LinearLrWarmup(LearningRateDecay):
from .. import layers
if self.step_num < self.warmup_steps:
return self.lr_ratio_before_warmup * self.step_num
return self.lr_ratio_before_warmup * self.step_num + self.start_lr
else:
return base_lr

Loading…
Cancel
Save