From c9b7d95c2cffe712a2ded4c56be5839a6581f198 Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com" <wangnan39@huawei.com>
Date: Wed, 3 Jun 2020 18:13:11 +0800
Subject: [PATCH] fix lr check bug in AdamWeightDecayDynamicLR

---
 mindspore/nn/optim/adam.py            | 12 +++++-------
 mindspore/nn/optim/lamb.py            | 11 ++++++-----
 tests/ut/python/nn/optim/test_adam.py |  2 +-
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index a256f0e0d8..704cbdf708 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -26,8 +26,6 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer
 
-_learning_rate_update_func = ['linear', 'cos', 'sin']
-
 adam_opt = C.MultitypeFuncGraph("adam_opt")
 
 
@@ -94,10 +92,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
 
 def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
     """Check the type of inputs."""
-    validator.check_float_positive('learning_rate', learning_rate, prim_name)
-    validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
-    validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
-    validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
+    validator.check_value_type("learning_rate", learning_rate, [float], prim_name)
+    validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
+    validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
+    validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
     validator.check_float_positive('power', power, prim_name)
     validator.check_float_legal_value('power', power, prim_name)
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
@@ -363,7 +361,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
                  eps=1e-6,
                  weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
+        super(AdamWeightDecayDynamicLR, self).__init__(0.0, params)
         if self.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index 420125684b..d8cc5b4ce4 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -111,10 +111,12 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
 def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                        end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
     """Check the type of inputs."""
-    validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
-    validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
+    validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
+    validator.check_number_range("start_learning_rate", start_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
+                                 prim_name)
     validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
-    validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
+    validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
+                                 prim_name)
     validator.check_float_positive('power', power, prim_name)
     validator.check_float_legal_value('power', power, prim_name)
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
@@ -180,8 +182,7 @@ class Lamb(Optimizer):
                  eps=1e-6,
                  weight_decay=0.0,
                  decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
-
-        super(Lamb, self).__init__(start_learning_rate, params)
+        super(Lamb, self).__init__(0.0, params)
         if self.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
         _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py
index e47a0d6704..5e6b6b129a 100644
--- a/tests/ut/python/nn/optim/test_adam.py
+++ b/tests/ut/python/nn/optim/test_adam.py
@@ -104,7 +104,7 @@ def test_AdamWeightDecayDynamicLR():
     _executor.compile(train_network, inputs, label)
 
 
-def test_adam_mindspore_flatten():
+def test_adam_mindspore_with_empty_params():
     net = nn.Flatten()
     with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
         AdamWeightDecay(net.get_parameters())
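
Note (not part of the patch): check_float_positive rejects 0.0, so the old checks disallowed a
zero learning rate, and the base Optimizer was initialized with the user-supplied value. After
this change the dynamic-LR optimizers pass a placeholder 0.0 to the base class and validate the
learning-rate arguments as floats in the half-open range [0.0, inf). The sketch below illustrates
the relaxed check using only the validator calls visible in the diff; the prim_name value and the
expectation that an out-of-range value raises ValueError are assumptions, not taken from the patch.

    # Illustrative sketch only -- assumes the validator helpers exist with the
    # signatures used in the diff above.
    from mindspore._checkparam import Validator as validator
    from mindspore._checkparam import Rel

    prim_name = "AdamWeightDecayDynamicLR"  # hypothetical caller name

    # 0.0 now passes: Rel.INC_LEFT validates the half-open range [0.0, inf).
    validator.check_value_type("end_learning_rate", 0.0, [float], prim_name)
    validator.check_number_range("end_learning_rate", 0.0, 0.0, float("inf"), Rel.INC_LEFT, prim_name)

    # A negative value is still rejected (assumed to raise ValueError).
    try:
        validator.check_number_range("end_learning_rate", -0.1, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
    except ValueError:
        print("negative end_learning_rate rejected")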