From e7560214be8da27be506c0b4d000b0a78b4d209a Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Sat, 23 May 2020 20:19:36 +0800
Subject: [PATCH] add lars parameter check

---
 mindspore/nn/optim/lars.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py
index 2efbb17d0b..3d85a05867 100755
--- a/mindspore/nn/optim/lars.py
+++ b/mindspore/nn/optim/lars.py
@@ -21,6 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
+from mindspore._checkparam import Validator as validator
 from .optimizer import grad_scale, Optimizer
 
 lars_opt = C.MultitypeFuncGraph("lars_opt")
@@ -41,6 +42,11 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f
     return gradient
 
 
+def _check_param_value(optimizer, epsilon, hyperpara, use_clip, prim_name):
+    validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
+    validator.check_value_type("epsilon", epsilon, [float], prim_name)
+    validator.check_value_type("hyperpara", hyperpara, [float], prim_name)
+    validator.check_value_type("use_clip", use_clip, [bool], prim_name)
 
 class LARS(Optimizer):
     """
@@ -79,9 +85,10 @@ class LARS(Optimizer):
     def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")], weight_decay, loss_scale)
         if optimizer.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
+        _check_param_value(optimizer, epsilon, hyperpara, use_clip, self.cls_name)
         self.opt = optimizer
         self.parameters = optimizer.parameters
         self.learning_rate = optimizer.learning_rate
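
For reference, a minimal usage sketch (not part of the patch) of what the added check changes: with _check_param_value in place, a wrongly typed constructor argument is rejected at construction time by the Validator instead of surfacing later during execution. The network and Momentum settings below are placeholders, and the sketch assumes the hyperpara/epsilon signature exactly as written in this diff.

    # sketch only: illustrates the effect of the added type check
    from mindspore import nn

    net = nn.Dense(16, 10)                                    # placeholder network
    sgd = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

    opt = nn.LARS(sgd, epsilon=1e-5, hyperpara=0.001)         # valid types: accepted

    try:
        nn.LARS(sgd, epsilon="1e-5")                          # str instead of float
    except TypeError as err:
        print(err)                                            # rejected by the new check

The other functional change in the last hunk is that weight_decay and loss_scale are now forwarded to the Optimizer base class instead of being silently dropped.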