!637 Learning rate and weight decay making group params

Merge pull request !637 from ghzl/learning-rate-make-group-mode
pull/637/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit deae380969

@@ -103,9 +103,9 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor")
def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
moment2):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
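The signature reordering above (moving `lr` from just after `opt` to just before `gradient`) is what lets the
learning rate be either bound once for all parameters or supplied per parameter in group mode. A minimal
plain-Python sketch of that idea, using functools.partial and map as stand-ins for F.partial and hyper_map
(illustration only; run_opt and the numbers below are hypothetical, not MindSpore internals):

    from functools import partial

    def run_opt(beta1, beta2, eps, lr, gradient, param):
        # Stand-in for _run_opt_with_one_number: leading args are shared, trailing args are per parameter.
        return param - lr * gradient

    params = [1.0, 2.0, 3.0]
    grads = [0.1, 0.2, 0.3]

    # Single learning rate: bind it together with the other shared hyper-parameters.
    bound = partial(run_opt, 0.9, 0.999, 1e-8, 0.01)
    print(list(map(bound, grads, params)))

    # Grouped learning rates: leave lr unbound and map one lr per parameter.
    group_lrs = [0.01, 0.1, 0.1]
    bound = partial(run_opt, 0.9, 0.999, 1e-8)
    print(list(map(bound, group_lrs, grads, params)))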
@@ -136,9 +136,27 @@ class Adam(Optimizer):
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
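For reference, a sketch of the standard bias-corrected Adam update these symbols refer to (the full formula
block sits outside this hunk; this is the usual formulation and may differ cosmetically from the docstring):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    l_t = \alpha \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)
    w_t = w_{t-1} - l_t \cdot m_t / (\sqrt{v_t} + \epsilon)

Here \beta_1^t and \beta_2^t are the running `beta1_power` and `beta2_power`.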
Note:
The Adam optimizer supports separating parameters into groups. Different parameter groups can set
different `learning_rate` and `weight_decay` values.
When parameter groups are used, the weight decay of each group is applied to that group's parameters
whenever its weight_decay > 0. When parameter groups are not used, the `weight_decay` in the API is
applied to parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` to be updated,
each element should be of class `Parameter`. When `params` is a list of `dict`, the keys "params",
"lr" and "weight_decay" can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the corresponding learning rate value is used.
If not, the `learning_rate` in the API is used.
- weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay value
is used. If not, the `weight_decay` in the API is used.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
@@ -161,8 +179,6 @@ class Adam(Optimizer):
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
1.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
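A minimal numeric sketch of the weight-decay rule described in the Note above, assuming the common L2-style
form in which the decay term is added to the gradient before the optimizer step (illustration only, not the
exact MindSpore implementation):

    def decayed_grad(grad, param, weight_decay, decay_applies):
        # Decay applies to a parameter only when its group (or the API) sets weight_decay > 0.
        return grad + weight_decay * param if decay_applies and weight_decay > 0 else grad

    # Grouped case: the group's own weight_decay is used.
    print(decayed_grad(0.5, 2.0, weight_decay=0.01, decay_applies=True))   # 0.52
    # Non-grouped case: the API-level weight_decay applies only to parameters whose
    # names contain neither 'beta' nor 'gamma'.
    name = "conv1.weight"
    print(decayed_grad(0.5, 2.0, weight_decay=0.01,
                       decay_applies=('beta' not in name and 'gamma' not in name)))  # 0.52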
@@ -172,15 +188,26 @@ class Adam(Optimizer):
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.Adam(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
>>> {'params': no_conv_params}]
>>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params group will use a learning rate of 0.01 and a weight decay of 0.01.
>>> # The no_conv_params group sets neither, so it will use the learning rate of 0.1 and
>>> # the weight decay of 0.0 given in the API.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
@@ -216,10 +243,14 @@ class Adam(Optimizer):
self.beta1_power = beta1_power
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power
success = self.hyper_map(F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps),
gradients, params, moment1, moment2)
if self.is_group:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps),
lr, gradients, params, moment1, moment2)
else:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps, lr),
gradients, params, moment1, moment2)
return success
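A small usage sketch of what the branch above implies for callers, mirroring the new tests added later in
this PR (assumes this PR's group API; the tiny network is arbitrary): in group mode the optimizer keeps one
learning-rate entry per parameter, which hyper_map then zips with the gradients, parameters and moments.

    import mindspore.nn as nn

    net = nn.Dense(3, 2)  # two parameters: weight and bias
    weight = [p for p in net.trainable_params() if 'weight' in p.name]
    bias = [p for p in net.trainable_params() if 'bias' in p.name]
    opt = nn.Adam([{'params': weight, 'lr': 0.01},
                   {'params': bias}], learning_rate=0.1)
    assert opt.is_group is True
    assert len(opt.learning_rate) == len(opt.parameters)  # one lr per parameter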
@@ -262,6 +293,8 @@ class AdamWeightDecay(Optimizer):
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(AdamWeightDecay, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
@@ -330,6 +363,8 @@ class AdamWeightDecayDynamicLR(Optimizer):
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
warmup_steps=0):
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
_check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations

@@ -96,7 +96,8 @@ class FTRL(Optimizer):
def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(FTRL, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay,
self.cls_name)
self.moments = self.parameters.clone(prefix="moments", init=initial_accum)

@@ -183,6 +183,8 @@ class Lamb(Optimizer):
decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
super(Lamb, self).__init__(start_learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
power, beta1, beta2, eps, weight_decay, self.cls_name)
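The same guard appears above for AdamWeightDecay, AdamWeightDecayDynamicLR and FTRL. A short sketch of what
it means for callers (assuming the import paths used by this PR's tests): passing a list of dicts to one of
these optimizers fails at construction time.

    import pytest
    import mindspore.nn as nn
    from mindspore.nn.optim import AdamWeightDecay

    net = nn.Dense(3, 2)
    group_params = [{'params': net.trainable_params()}]
    with pytest.raises(RuntimeError):
        # "The AdamWeightDecay optimizer cannot support group setting."
        AdamWeightDecay(group_params)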

@@ -23,7 +23,7 @@ momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
@@ -36,9 +36,27 @@ class Momentum(Optimizer):
Refer to the paper on the importance of initialization and momentum in deep learning for more details.
Note:
The Momentum optimizer supports separating parameters into groups. Different parameter groups can set
different `learning_rate` and `weight_decay` values.
When parameter groups are used, the weight decay of each group is applied to that group's parameters
whenever its weight_decay > 0. When parameter groups are not used, the `weight_decay` in the API is
applied to parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters`
should be class mindspore.Parameter.
params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` to be updated,
each element should be of class `Parameter`. When `params` is a list of `dict`, the keys "params",
"lr" and "weight_decay" can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the corresponding learning rate value is used.
If not, the `learning_rate` in the API is used.
- weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay value
is used. If not, the `weight_decay` in the API is used.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
@@ -49,8 +67,6 @@ class Momentum(Optimizer):
momentum (float): Hyperparameter of type float, means momentum for the moving average.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
lambda x: 'beta' not in x.name and 'gamma' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -63,13 +79,24 @@ class Momentum(Optimizer):
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
>>> {'params': no_conv_params}]
>>> opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
>>> # The conv_params group will use a learning rate of 0.01 and a weight decay of 0.01.
>>> # The no_conv_params group sets neither, so it will use the learning rate of 0.1 and
>>> # the weight decay of 0.0 given in the API.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0):
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
@@ -84,5 +111,8 @@ class Momentum(Optimizer):
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
if self.is_group:
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
else:
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success

File diff suppressed because it is too large.

@@ -22,17 +22,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
@rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
"""Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
return success
@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
def _centered_rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
"""Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True
success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
@@ -44,6 +44,13 @@ class RMSProp(Optimizer):
Implements Root Mean Squared Propagation (RMSProp) algorithm.
Note:
The RMSProp optimizer supports separating parameters into groups. Different parameter groups can set
different `learning_rate` and `weight_decay` values.
When parameter groups are used, the weight decay of each group is applied to that group's parameters
whenever its weight_decay > 0. When parameter groups are not used, the `weight_decay` in the API is
applied to parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0.
Update `params` according to the RMSProp algorithm.
The equation is as follows:
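The equation block itself falls outside this hunk; for reference, a sketch of the standard RMSProp update
the code below wraps (usual formulation, stated here for context):

    s_t = \rho \, s_{t-1} + (1 - \rho) g_t^2
    m_t = \mu \, m_{t-1} + \alpha \, g_t / \sqrt{s_t + \epsilon}
    w_t = w_{t-1} - m_t

where \rho is `decay`, \mu is `momentum` and \alpha is the learning rate. With centered=True, a running mean
of the gradient mg is also kept and the denominator becomes \sqrt{s_t - mg_t^2 + \epsilon}.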
@@ -84,8 +91,18 @@ class RMSProp(Optimizer):
represents `gradients`.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters`
should be class mindspore.Parameter.
params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` to be updated,
each element should be of class `Parameter`. When `params` is a list of `dict`, the keys "params",
"lr" and "weight_decay" can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the corresponding learning rate value is used.
If not, the `learning_rate` in the API is used.
- weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay value
is used. If not, the `weight_decay` in the API is used.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
@@ -95,15 +112,13 @@ class RMSProp(Optimizer):
Other cases are not supported. Default: 0.1.
decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
greater than 0.Default: 0.0.
greater than 0. Default: 0.0.
epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
0. Default: 1e-10.
use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False.
centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
lambda x: 'beta' not in x.name and 'gamma' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -113,14 +128,25 @@ class RMSProp(Optimizer):
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.RMSProp(params=net.trainable_params(), learning_rate=lr)
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
>>> {'params': no_conv_params}]
>>> opt = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params group will use a learning rate of 0.01 and a weight decay of 0.01.
>>> # The no_conv_params group sets neither, so it will use the learning rate of 0.1 and
>>> # the weight decay of 0.0 given in the API.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr)
>>> model = Model(net, loss, opt)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0):
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("decay", decay, [float], self.cls_name)
validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_value_type("momentum", momentum, [float], self.cls_name)
@@ -150,9 +176,18 @@ class RMSProp(Optimizer):
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.centered:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
self.momentum), params, self.mg, self.ms, self.moment, gradients)
if self.is_group:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
else:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
else:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
self.momentum), params, self.ms, self.moment, gradients)
if self.is_group:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.ms, self.moment, gradients)
else:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum, lr), params, self.ms, self.moment, gradients)
return success

@@ -24,7 +24,7 @@ sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat):
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
"""Apply sgd optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
@@ -39,9 +39,27 @@ class SGD(Optimizer):
Nesterov momentum is based on the formula from paper `On the importance of initialization and
momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
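For reference, a sketch of the momentum update this refers to, in the common formulation with dampening and
the Nesterov variant (stated for context; the exact kernel semantics may differ in details):

    v_t = \mu \, v_{t-1} + (1 - d) \, g_t
    w_t = w_{t-1} - \alpha \, v_t              (plain momentum)
    w_t = w_{t-1} - \alpha (g_t + \mu \, v_t)  (nesterov=True)

where \mu is `momentum`, d is `dampening` and \alpha is the learning rate.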
Note:
The SGD optimizer supports separating parameters into groups. Different parameter groups can set
different `learning_rate` and `weight_decay` values.
When parameter groups are used, the weight decay of each group is applied to that group's parameters
whenever its weight_decay > 0. When parameter groups are not used, the `weight_decay` in the API is
applied to parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` to be updated,
each element should be of class `Parameter`. When `params` is a list of `dict`, the keys "params",
"lr" and "weight_decay" can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the corresponding learning rate value is used.
If not, the `learning_rate` in the API is used.
- weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay value
is used. If not, the `weight_decay` in the API is used.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
@@ -67,9 +85,21 @@ class SGD(Optimizer):
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.SGD(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
>>> {'params': no_conv_params}]
>>> opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params group will use a learning rate of 0.01 and a weight decay of 0.01.
>>> # The no_conv_params group sets neither, so it will use the learning rate of 0.1 and
>>> # the weight decay of 0.0 given in the API.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
loss_scale=1.0):
@@ -109,5 +139,8 @@ class SGD(Optimizer):
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
if self.is_group:
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
else:
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
return success

@@ -167,7 +167,7 @@ class TrainOneStepCell(Cell):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
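Taking `self.weights` from `optimizer.parameters` instead of re-collecting `network.trainable_params()`
keeps gradient i aligned with the optimizer's parameter i and its per-parameter state, which matters once
group dicts are flattened into a single ordered tuple. A plain-Python sketch of that flattening (assumed
behaviour, not MindSpore internals):

    group_params = [{'params': ['conv1.weight', 'conv2.weight'], 'lr': 0.01},
                    {'params': ['fc1.weight', 'fc1.bias']}]
    default_lr = 0.1

    flat_params, flat_lrs = [], []
    for group in group_params:
        lr = group.get('lr', default_lr)
        for p in group['params']:
            flat_params.append(p)
            flat_lrs.append(lr)

    print(list(zip(flat_params, flat_lrs)))  # parameter order defines the state order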

@@ -50,7 +50,7 @@ class NetWithoutWeight(nn.Cell):
def test_adamwithoutparam():
net = NetWithoutWeight()
net.set_train()
with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"):
with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
@@ -104,5 +104,5 @@ def test_AdamWeightDecayDynamicLR():
def test_adam_mindspore_flatten():
net = nn.Flatten()
with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"):
with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
AdamWeightDecay(net.get_parameters())

@@ -69,19 +69,19 @@ class TestSGD():
class TestNullParam():
""" TestNullParam definition """
def test_optim_init(self):
with pytest.raises(TypeError):
with pytest.raises(ValueError):
Optimizer(0.1, None)
def test_AdamWightDecay_init(self):
with pytest.raises(TypeError):
with pytest.raises(ValueError):
AdamWeightDecay(None)
def test_AdamWeightDecayDynamicLR_init(self):
with pytest.raises(TypeError):
with pytest.raises(ValueError):
AdamWeightDecayDynamicLR(None, 10)
def test_Sgd_init(self):
with pytest.raises(TypeError):
with pytest.raises(ValueError):
SGD(None)
class TestUnsupportParam():

@@ -0,0 +1,210 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.nn.optim import Momentum, SGD, RMSProp, Adam
from mindspore import context
from mindspore.common.api import _executor
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.nn import TrainOneStepCell, WithLossCell
context.set_context(mode=context.GRAPH_MODE)
class LeNet5(nn.Cell):
""" LeNet5 definition """
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def test_group_lr():
inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([1, 10]).astype(np.float32))
net = LeNet5()
conv_lr = 0.8
default_lr = 0.1
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'lr': conv_lr},
{'params': no_conv_params}]
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
assert opt.is_group is True
assert opt.dynamic_lr is False
for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params:
assert lr.data == Tensor(conv_lr, mstype.float32)
else:
assert lr.data == Tensor(default_lr, mstype.float32)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt)
_executor.compile(train_network, inputs, label)
def test_group_dynamic_1():
inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([1, 10]).astype(np.float32))
net = LeNet5()
conv_lr = 0.8
default_lr = (0.1, 0.2, 0.3)
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'lr': conv_lr},
{'params': no_conv_params}]
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
assert opt.is_group is True
assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params:
assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
else:
assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt)
_executor.compile(train_network, inputs, label)
def test_group_dynamic_2():
inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([1, 10]).astype(np.float32))
net = LeNet5()
conv_lr = (0.1, 0.2, 0.3)
default_lr = 0.8
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'lr': conv_lr},
{'params': no_conv_params}]
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = RMSProp(group_params, learning_rate=default_lr)
assert opt.is_group is True
assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params:
assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
else:
assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt)
_executor.compile(train_network, inputs, label)
def test_group_dynamic_no_same_size():
net = LeNet5()
conv_lr = (0.1, 0.2, 0.3)
default_lr = (0.1, 0.2)
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'lr': conv_lr},
{'params': no_conv_params}]
with pytest.raises(ValueError):
Momentum(group_params, learning_rate=default_lr, momentum=0.9)
def test_group_not_float_lr():
net = LeNet5()
conv_lr = 1
default_lr = 0.3
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'lr': conv_lr},
{'params': no_conv_params}]
with pytest.raises(TypeError):
Momentum(group_params, learning_rate=default_lr, momentum=0.9)
def test_group_not_float_weight_decay():
net = LeNet5()
conv_weight_decay = 1
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
{'params': no_conv_params}]
with pytest.raises(TypeError):
Momentum(group_params, learning_rate=0.1, momentum=0.9)
def test_weight_decay():
inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([1, 10]).astype(np.float32))
net = LeNet5()
conv_weight_decay = 0.8
default_weight_decay = 0.0
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
{'params': no_conv_params}]
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
assert opt.is_group is True
for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
if param in conv_params:
assert weight_decay == conv_weight_decay
assert decay_flags is True
else:
assert weight_decay == default_weight_decay
assert decay_flags is False
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt)
_executor.compile(train_network, inputs, label)
def test_group_repeat_param():
net = LeNet5()
conv_lr = 0.1
default_lr = 0.3
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'lr': conv_lr},
{'params': conv_params, 'lr': default_lr},
{'params': no_conv_params}]
with pytest.raises(RuntimeError):
Adam(group_params, learning_rate=default_lr)