!1721 adjust warmup_steps in AdamWeightDecayDynamicLR

Merge pull request !1721 from yoonlee666/master-wenti
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit ca44b3919d

@@ -325,9 +325,10 @@ class AdamWeightDecayDynamicLR(Optimizer):
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
             should be class mindspore.Parameter.
         decay_steps (int): The steps of the decay.
+        warmup_steps (int): The steps of lr warm up. Default: 0.
         learning_rate (float): A floating point value for the learning rate. Default: 0.001.
         end_learning_rate (float): A floating point value for the end learning rate. Default: 0.0001.
-        power (float): Power. Default: 10.0.
+        power (float): The Power of the polynomial. Default: 10.0.
         beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
             Should be in range (0.0, 1.0).
         beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
@@ -353,6 +354,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
     def __init__(self,
                  params,
                  decay_steps,
+                 warmup_steps=0,
                  learning_rate=0.001,
                  end_learning_rate=0.0001,
                  power=10.0,
@@ -360,8 +362,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
                  beta2=0.999,
                  eps=1e-6,
                  weight_decay=0.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
-                 warmup_steps=0):
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
         if self.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")

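After this change, warmup_steps is accepted directly after decay_steps instead of trailing decay_filter, and it is documented in the Args section. A minimal usage sketch of the updated signature follows; the import path, the toy nn.Dense network, and the concrete step counts are assumptions for illustration, not part of this commit.

import mindspore.nn as nn
from mindspore.nn import AdamWeightDecayDynamicLR  # assumed import path for this (older) API

net = nn.Dense(16, 4)  # any Cell with trainable parameters will do
optimizer = AdamWeightDecayDynamicLR(net.trainable_params(),
                                     decay_steps=10000,   # length of the polynomial decay schedule
                                     warmup_steps=1000,    # now declared right after decay_steps
                                     learning_rate=1e-3,
                                     end_learning_rate=1e-4,
                                     power=10.0,
                                     weight_decay=0.01)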