diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index e026b1c560..97a81a590b 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -114,7 +114,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate, _ = warmup_steps validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name) validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name) - validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name) + validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name) validator.check_float_positive('power', power, prim_name) validator.check_float_legal_value('power', power, prim_name)