From 878f03ad6c262ef38122575180501506770cd947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=98=89=E7=90=AA?= Date: Mon, 3 Aug 2020 11:10:41 +0800 Subject: [PATCH] Modify logic --- mindspore/nn/optim/sgd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 4801104215..bfff475ce1 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -134,10 +134,6 @@ class SGD(Optimizer): if isinstance(momentum, float) and momentum < 0.0: raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) - if nesterov and (momentum <= 0.0 or dampening != 0.0): - raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," - "but got momentum {}, dampening {}".format(momentum, dampening)) - if isinstance(dampening, int): dampening = float(dampening) if not isinstance(dampening, float): @@ -151,6 +147,10 @@ class SGD(Optimizer): weight_decay = float(weight_decay) validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) + + if nesterov and (momentum <= 0.0 or dampening != 0.0): + raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," + "but got momentum {}, dampening {}".format(momentum, dampening)) self.nesterov = nesterov self.opt = P.SGD(dampening, weight_decay, nesterov)