diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index e3ab616ddd..2efbb17d0b 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -80,6 +80,8 @@ class LARS(Optimizer): decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0): super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")]) + if optimizer.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") self.opt = optimizer self.parameters = optimizer.parameters self.learning_rate = optimizer.learning_rate diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index d931e5a52f..05560d9739 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -81,7 +81,7 @@ class Optimizer(Cell): raise ValueError("Optimizer got an empty parameter list.") if not isinstance(parameters[0], (dict, Parameter)): - raise ValueError("Only a list of Parameter or dict can be supported.") + raise TypeError("Only a list of Parameter or dict can be supported.") if isinstance(loss_scale, int): loss_scale = float(loss_scale) @@ -258,9 +258,9 @@ class Optimizer(Cell): for param in group_param['params']: validator.check_value_type("parameter", param, [Parameter], self.cls_name) - if param in params_store: + if param.name in params_store: raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") - params_store.append(param) + params_store.append(param.name) self.group_lr.append(Parameter(lr, name="lr_" + param.name)) self.group_weight_decay.append(weight_decay_) @@ -298,18 +298,22 @@ class Optimizer(Cell): Parameter, single `Parameter` or `list[Parameter]` according to the input type. """ if not isinstance(param, (Parameter, list)): - raise TypeError(f"The 'param' only support 'Parameter' or 'list' type.") + raise TypeError(f"The parameter only support 'Parameter' or 'list' type.") if isinstance(param, list): lr = [] for p in param: validator.check_value_type("parameter", p, [Parameter], self.cls_name) + if p not in self.parameters: + raise ValueError(f"The parameter {p.name} is not in optimizer.") if self.is_group_lr: index = self.parameters.index(p) lr.append(self.learning_rate[index]) else: lr.append(self.learning_rate) else: + if param not in self.parameters: + raise ValueError(f"The parameter {param.name} is not in optimizer.") if self.is_group_lr: index = self.parameters.index(param) lr = self.learning_rate[index] diff --git a/tests/ut/python/nn/optim/test_optimizer.py b/tests/ut/python/nn/optim/test_optimizer.py index 6594550687..548094840e 100644 --- a/tests/ut/python/nn/optim/test_optimizer.py +++ b/tests/ut/python/nn/optim/test_optimizer.py @@ -94,7 +94,7 @@ class TestUnsupportParam(): """ TestUnsupportParam definition """ def test_optim_init(self): - with pytest.raises(ValueError): + with pytest.raises(TypeError): Optimizer(0.1, (1, 2, 3)) def test_AdamWightDecay_init(self):