learning rate and weight decay support group mode

pull/637/head
guohongzilong 5 years ago
parent 883fde0494
commit 824bc30a94

@@ -102,9 +102,9 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)

-@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor")
-def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
+def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1,
                              moment2):
     """Apply adam optimizer to the weight parameter using Tensor."""
     success = True

@@ -135,9 +135,27 @@ class Adam(Optimizer):
     `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
     :math:`\epsilon` represents `eps`.

+    Note:
+        The Adam optimizer supports separating parameter groups. Different parameter groups can set different
+        `learning_rate` and `weight_decay`.
+
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
+        applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
+
     Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
-            should be class mindspore.Parameter.
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr" and "weight_decay" are the keys that can be parsed.
+
+            - params: Required. The value should be a list of `Parameter`.
+            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
         learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
                                                         Iterable or a Tensor and the dims of the Tensor is 1,
                                                         use dynamic learning rate, then the i-th step will

@@ -160,8 +178,6 @@ class Adam(Optimizer):
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
             1.0.
-        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
-            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

@@ -171,15 +187,26 @@ class Adam(Optimizer):
     Examples:
         >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> #1) All parameters use the same learning rate and weight decay
         >>> optim = nn.Adam(params=net.trainable_params())
-        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
+        >>>                 {'params': no_conv_params}]
+        >>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
+        >>> # the no_conv_params's parameters don't set a learning rate or weight decay, so they will use a
+        >>> # learning rate of 0.1 and a weight decay of 0.0
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
     """

     def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
-                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
+                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
+        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
         validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
         validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

@@ -215,10 +242,14 @@ class Adam(Optimizer):
         self.beta1_power = beta1_power
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
-        success = self.hyper_map(F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power, self.beta1,
-                                           self.beta2, self.eps),
-                                 gradients, params, moment1, moment2)
+        if self.is_group:
+            success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
+                                               self.beta2, self.eps),
+                                     lr, gradients, params, moment1, moment2)
+        else:
+            success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
+                                               self.beta2, self.eps, lr),
+                                     gradients, params, moment1, moment2)
         return success
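The branch above is the core of the group-mode change: in the non-group case the scalar learning rate is bound into `F.partial` and shared by every parameter, while in the group case `lr` is a per-parameter sequence that `hyper_map` zips elementwise with `gradients`, `params` and the moments. A rough pure-Python analogy of the two calls (illustrative names only, not MindSpore APIs):

from functools import partial

def apply_update(lr, grad, param):
    # stand-in for one per-parameter optimizer step
    return param - lr * grad

grads = [0.5, 0.2, 0.1]
params = [1.0, 2.0, 3.0]

# non-group mode: lr is bound once (like F.partial(..., lr)) and shared by all parameters
shared = list(map(partial(apply_update, 0.1), grads, params))

# group mode: lr is a per-parameter sequence zipped in (like passing lr as the first mapped argument)
per_param_lr = [0.01, 0.1, 0.1]
grouped = list(map(apply_update, per_param_lr, grads, params))

print(shared)   # every parameter used lr=0.1
print(grouped)  # the first parameter used its group lr=0.01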
@@ -261,6 +292,8 @@ class AdamWeightDecay(Optimizer):
     def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(AdamWeightDecay, self).__init__(learning_rate, params)
+        if self.is_group:
+            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
         self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
         self.beta2 = Tensor(np.array([beta2]).astype(np.float32))

@@ -328,6 +361,8 @@ class AdamWeightDecayDynamicLR(Optimizer):
                  weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
+        if self.is_group:
+            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
         _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
         # turn them to scalar when me support scalar/tensor mix operations

@@ -96,7 +96,8 @@ class FTRL(Optimizer):
     def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
                  use_locking=False, loss_scale=1.0, weight_decay=0.0):
         super(FTRL, self).__init__(learning_rate, params)
+        if self.is_group:
+            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
         _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay,
                      self.cls_name)
         self.moments = self.parameters.clone(prefix="moments", init=initial_accum)

@@ -183,6 +183,8 @@ class Lamb(Optimizer):
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
         super(Lamb, self).__init__(start_learning_rate, params)
+        if self.is_group:
+            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
         _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
                            power, beta1, beta2, eps, weight_decay, self.cls_name)
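The optimizers that are not converted to group mode in this commit (AdamWeightDecay, AdamWeightDecayDynamicLR, FTRL and Lamb) now fail fast when they are given group dictionaries. A small hedged sketch of what that guard means for callers; the nn.Dense cell is only an illustration and the behavior is taken from the checks added above:

import pytest
import mindspore.nn as nn
from mindspore.nn.optim import FTRL

net = nn.Dense(3, 4)  # any Cell with trainable parameters
group_params = [{'params': net.trainable_params(), 'lr': 0.01}]

# Passing a list of dicts marks the optimizer as grouped, which FTRL rejects.
with pytest.raises(RuntimeError):
    FTRL(group_params, learning_rate=0.1)  # "The FTRL optimizer cannot support group setting."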

@@ -23,7 +23,7 @@ momentum_opt = C.MultitypeFuncGraph("momentum_opt")

 @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
+def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
     """Apply momentum optimizer to the weight parameter using Tensor."""
     success = True
     success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))

@@ -36,9 +36,27 @@ class Momentum(Optimizer):
     Refer to the paper on the importance of initialization and momentum in deep learning for more details.

+    Note:
+        The Momentum optimizer supports separating parameter groups. Different parameter groups can set different
+        `learning_rate` and `weight_decay`.
+
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
+        applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
+
     Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters`
-            should be class mindspore.Parameter.
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr" and "weight_decay" are the keys that can be parsed.
+
+            - params: Required. The value should be a list of `Parameter`.
+            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
         learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
                                                         Iterable or a Tensor and the dims of the Tensor is 1,
                                                         use dynamic learning rate, then the i-th step will

@@ -49,8 +67,6 @@ class Momentum(Optimizer):
         momentum (float): Hyperparameter of type float, means momentum for the moving average.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
-        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
-            lambda x: 'beta' not in x.name and 'gamma' not in x.name.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

@@ -63,13 +79,24 @@ class Momentum(Optimizer):
     Examples:
         >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> #1) All parameters use the same learning rate and weight decay
         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
+        >>>                 {'params': no_conv_params}]
+        >>> opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
+        >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
+        >>> # the no_conv_params's parameters don't set a learning rate or weight decay, so they will use a
+        >>> # learning rate of 0.1 and a weight decay of 0.0
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """

-    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
+    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0):
+        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")

@@ -84,5 +111,8 @@ class Momentum(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
+        if self.is_group:
+            success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
+        else:
+            success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
         return success
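Dropping `decay_filter` here (and in Adam and RMSProp) moves the decay decision into the groups themselves: a group decays its parameters only when it carries a positive `weight_decay`, otherwise the API-level default applies. A hedged sketch of how that surfaces on a constructed optimizer, mirroring the test_weight_decay case added further below; the nn.Dense cell and the attribute names (is_group, weight_decay, decay_flags, parameters) are taken from this commit's tests, not guaranteed elsewhere:

import mindspore.nn as nn
from mindspore.nn.optim import SGD

net = nn.Dense(3, 4)
weight_params = [p for p in net.trainable_params() if 'weight' in p.name]
bias_params = [p for p in net.trainable_params() if 'weight' not in p.name]
group_params = [{'params': weight_params, 'weight_decay': 0.01},
                {'params': bias_params}]   # no key, so it falls back to the API default (0.0)

opt = SGD(group_params, learning_rate=0.1, weight_decay=0.0)
assert opt.is_group is True
for param, wd, decays in zip(opt.parameters, opt.weight_decay, opt.decay_flags):
    # expected: weight -> 0.01 / True, bias -> 0.0 / False
    print(param.name, wd, decays)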

File diff suppressed because it is too large.

@@ -22,17 +22,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")

-@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
+@rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
     """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
     success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
     return success

-@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
+@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
                                "Tensor", "Tensor")
-def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
+def _centered_rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad):
     """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
     success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))

@@ -44,6 +44,13 @@ class RMSProp(Optimizer):
     Implements Root Mean Squared Propagation (RMSProp) algorithm.

     Note:
+        The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different
+        `learning_rate` and `weight_decay`.
+
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
+        applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
+
         Update `params` according to the RMSProp algorithm.

         The equation is as follows:

@@ -84,8 +91,18 @@ class RMSProp(Optimizer):
         represents `gradients`.

     Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters`
-            should be class mindspore.Parameter.
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr" and "weight_decay" are the keys that can be parsed.
+
+            - params: Required. The value should be a list of `Parameter`.
+            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
         learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
                                                         Iterable or a Tensor and the dims of the Tensor is 1,
                                                         use dynamic learning rate, then the i-th step will

@@ -95,15 +112,13 @@ class RMSProp(Optimizer):
             Other cases are not supported. Default: 0.1.
         decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
         momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
-            greater than 0.Default: 0.0.
+            greater than 0. Default: 0.0.
         epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
             0. Default: 1e-10.
         use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
         centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False.
         loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
         weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
-        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
-            lambda x: 'beta' not in x.name and 'gamma' not in x.name.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

@@ -113,14 +128,25 @@ class RMSProp(Optimizer):
     Examples:
         >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = nn.RMSProp(params=net.trainable_params(), learning_rate=lr)
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
+        >>>                 {'params': no_conv_params}]
+        >>> opt = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
+        >>> # the no_conv_params's parameters don't set a learning rate or weight decay, so they will use a
+        >>> # learning rate of 0.1 and a weight decay of 0.0
+        >>>
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
-        >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr)
-        >>> model = Model(net, loss, opt)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
     """

     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
-                 use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
+                 use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0):
+        super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale)
         validator.check_value_type("decay", decay, [float], self.cls_name)
         validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
         validator.check_value_type("momentum", momentum, [float], self.cls_name)

@@ -150,9 +176,18 @@ class RMSProp(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
-            success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
-                                               self.momentum), params, self.mg, self.ms, self.moment, gradients)
+            if self.is_group:
+                success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
+                                                   self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
+            else:
+                success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
+                                                   self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
         else:
-            success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
-                                               self.momentum), params, self.ms, self.moment, gradients)
+            if self.is_group:
+                success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
+                                                   self.momentum), lr, params, self.ms, self.moment, gradients)
+            else:
+                success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
+                                                   self.momentum, lr), params, self.ms, self.moment, gradients)
         return success

@@ -24,7 +24,7 @@ sgd_opt = C.MultitypeFuncGraph("sgd_opt")

 @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat):
+def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
     """Apply sgd optimizer to the weight parameter using Tensor."""
     success = True
     success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))

@@ -39,9 +39,27 @@ class SGD(Optimizer):
     Nesterov momentum is based on the formula from paper `On the importance of initialization and
     momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.

+    Note:
+        The SGD optimizer supports separating parameter groups. Different parameter groups can set different
+        `learning_rate` and `weight_decay`.
+
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
+        applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
+
     Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
-            should be class mindspore.Parameter.
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr" and "weight_decay" are the keys that can be parsed.
+
+            - params: Required. The value should be a list of `Parameter`.
+            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
         learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
                                                         Iterable or a Tensor and the dims of the Tensor is 1,
                                                         use dynamic learning rate, then the i-th step will

@@ -67,9 +85,21 @@ class SGD(Optimizer):
     Examples:
         >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.SGD(params=net.trainable_params())
-        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01},
+        >>>                 {'params': no_conv_params}]
+        >>> opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01
+        >>> # the no_conv_params's parameters don't set a learning rate or weight decay, so they will use a
+        >>> # learning rate of 0.1 and a weight decay of 0.0
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
     """

     def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
                  loss_scale=1.0):

@@ -109,5 +139,8 @@ class SGD(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
+        if self.is_group:
+            success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
+        else:
+            success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
         return success

@@ -167,7 +167,7 @@ class TrainOneStepCell(Cell):
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
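Because a grouped optimizer flattens its `params` dictionaries into its own `parameters` tuple, which may be ordered differently from `network.trainable_params()`, TrainOneStepCell now differentiates with respect to exactly that tuple. A custom training wrapper should follow the same pattern; this is a minimal hedged sketch, not the MindSpore implementation, and it assumes `network` already computes the loss (for example a WithLossCell):

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F

class SimpleTrainStep(nn.Cell):
    def __init__(self, network, optimizer):
        super(SimpleTrainStep, self).__init__(auto_prefix=False)
        self.network = network
        self.optimizer = optimizer
        self.weights = optimizer.parameters          # same source as TrainOneStepCell now uses
        self.grad = C.GradOperation('grad', get_by_list=True)

    def construct(self, data, label):
        loss = self.network(data, label)
        # gradients are taken over the optimizer's parameter tuple, so they line
        # up one-to-one with the grouped learning rates and weight decays
        grads = self.grad(self.network, self.weights)(data, label)
        return F.depend(loss, self.optimizer(grads))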

@@ -50,7 +50,7 @@ class NetWithoutWeight(nn.Cell):
 def test_adamwithoutparam():
     net = NetWithoutWeight()
     net.set_train()
-    with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"):
+    with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
         AdamWeightDecay(net.trainable_params(), learning_rate=0.1)

@@ -104,5 +104,5 @@ def test_AdamWeightDecayDynamicLR():
 def test_adam_mindspore_flatten():
     net = nn.Flatten()
-    with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"):
+    with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
         AdamWeightDecay(net.get_parameters())

@@ -69,19 +69,19 @@ class TestSGD():
 class TestNullParam():
     """ TestNullParam definition """
     def test_optim_init(self):
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             Optimizer(0.1, None)

     def test_AdamWightDecay_init(self):
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             AdamWeightDecay(None)

     def test_AdamWeightDecayDynamicLR_init(self):
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             AdamWeightDecayDynamicLR(None, 10)

     def test_Sgd_init(self):
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             SGD(None)


 class TestUnsupportParam():

@@ -0,0 +1,210 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.nn.optim import Momentum, SGD, RMSProp, Adam
from mindspore import context
from mindspore.common.api import _executor
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.nn import TrainOneStepCell, WithLossCell

context.set_context(mode=context.GRAPH_MODE)


class LeNet5(nn.Cell):
    """ LeNet5 definition """
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def test_group_lr():
    inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([1, 10]).astype(np.float32))
    net = LeNet5()
    conv_lr = 0.8
    default_lr = 0.1
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params}]
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
    assert opt.is_group is True
    assert opt.dynamic_lr is False
    for lr, param in zip(opt.learning_rate, opt.parameters):
        if param in conv_params:
            assert lr.data == Tensor(conv_lr, mstype.float32)
        else:
            assert lr.data == Tensor(default_lr, mstype.float32)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, opt)
    _executor.compile(train_network, inputs, label)


def test_group_dynamic_1():
    inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([1, 10]).astype(np.float32))
    net = LeNet5()
    conv_lr = 0.8
    default_lr = (0.1, 0.2, 0.3)
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params}]
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
    assert opt.is_group is True
    assert opt.dynamic_lr is True
    for lr, param in zip(opt.learning_rate, opt.parameters):
        if param in conv_params:
            assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
        else:
            assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, opt)
    _executor.compile(train_network, inputs, label)


def test_group_dynamic_2():
    inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([1, 10]).astype(np.float32))
    net = LeNet5()
    conv_lr = (0.1, 0.2, 0.3)
    default_lr = 0.8
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params}]
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = RMSProp(group_params, learning_rate=default_lr)
    assert opt.is_group is True
    assert opt.dynamic_lr is True
    for lr, param in zip(opt.learning_rate, opt.parameters):
        if param in conv_params:
            assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
        else:
            assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, opt)
    _executor.compile(train_network, inputs, label)


def test_group_dynamic_no_same_size():
    net = LeNet5()
    conv_lr = (0.1, 0.2, 0.3)
    default_lr = (0.1, 0.2)
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params}]
    with pytest.raises(ValueError):
        Momentum(group_params, learning_rate=default_lr, momentum=0.9)


def test_group_not_float_lr():
    net = LeNet5()
    conv_lr = 1
    default_lr = 0.3
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params}]
    with pytest.raises(TypeError):
        Momentum(group_params, learning_rate=default_lr, momentum=0.9)


def test_group_not_float_weight_decay():
    net = LeNet5()
    conv_weight_decay = 1
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'params': no_conv_params}]
    with pytest.raises(TypeError):
        Momentum(group_params, learning_rate=0.1, momentum=0.9)


def test_weight_decay():
    inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([1, 10]).astype(np.float32))
    net = LeNet5()
    conv_weight_decay = 0.8
    default_weight_decay = 0.0
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'params': no_conv_params}]
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
    assert opt.is_group is True
    for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
        if param in conv_params:
            assert weight_decay == conv_weight_decay
            assert decay_flags is True
        else:
            assert weight_decay == default_weight_decay
            assert decay_flags is False
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, opt)
    _executor.compile(train_network, inputs, label)


def test_group_repeat_param():
    net = LeNet5()
    conv_lr = 0.1
    default_lr = 0.3
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': conv_params, 'lr': default_lr},
                    {'params': no_conv_params}]
    with pytest.raises(RuntimeError):
        Adam(group_params, learning_rate=default_lr)
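The two dynamic-lr tests above also pin down a broadcasting rule: once any group (or the API default) uses an iterable learning rate, scalar rates in the other groups are expanded to the same number of steps, and mismatched iterable lengths raise ValueError. A hedged sketch of the same check outside LeNet5, assuming a plain nn.Dense cell and the attribute names used in these tests:

import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.nn.optim import Momentum

net = nn.Dense(3, 4)
weight_params = [p for p in net.trainable_params() if 'weight' in p.name]
bias_params = [p for p in net.trainable_params() if 'weight' not in p.name]
group_params = [{'params': weight_params, 'lr': 0.8},  # scalar lr, expected to broadcast to 3 steps
                {'params': bias_params}]                # falls back to the default (dynamic) lr

opt = Momentum(group_params, learning_rate=(0.1, 0.2, 0.3), momentum=0.9)
assert opt.is_group is True
assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters):
    expected = [0.8] * 3 if param in weight_params else [0.1, 0.2, 0.3]
    assert lr.data == Tensor(np.array(expected).astype(np.float32))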