uniform learning_rate behavior of optimizers

pull/2755/head
wangnan39@huawei.com 5 years ago
parent 57252dee24
commit 082433183d

@ -231,8 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
>>> cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
"""
validator.check_float_positive('min_lr', min_lr, None)
validator.check_float_legal_value('min_lr', min_lr, None)
if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_float_positive('max_lr', max_lr, None)
validator.check_float_legal_value('max_lr', max_lr, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
@ -288,8 +289,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
"""
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_float_positive('end_learning_rate', end_learning_rate, None)
validator.check_float_legal_value('end_learning_rate', end_learning_rate, None)
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_float_positive('power', power, None)
validator.check_float_legal_value('power', power, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
@ -311,11 +313,58 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
return lr
def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
r"""
Get learning rate warming up.
For the i-th step, the formula of computing warmup_learning_rate[i] is:
.. math::
warmup\_learning\_rate[i] = learning\_rate * tmp\_epoch / tmp\_warmup\_epoch
Where :math:`tmp\_epoch=min(current\_epoch, warmup\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`
Args:
learning_rate (float): The initial value of learning rate.
warmup_steps (int): The warm up steps of learning rate.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> learning_rate = 0.1
>>> total_step = 6
>>> step_per_epoch = 2
>>> warmup_epoch = 2
>>> warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch)
[0.0, 0.0, 0.05, 0.05, 0.1, 0.1]
"""
if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
function = lambda x, y: (x, min(x, y))
lr = []
for i in range(total_step):
current_epoch = math.floor(i / step_per_epoch)
warmup_epoch, tmp_epoch = function(warmup_epoch, current_epoch)
lr.append(learning_rate * tmp_epoch/ warmup_epoch)
return lr
__all__ = [
'piecewise_constant_lr',
'exponential_decay_lr',
'natural_exp_decay_lr',
'inverse_decay_lr',
'cosine_decay_lr',
'polynomial_decay_lr'
'polynomial_decay_lr',
'warmup_lr'
]

File diff suppressed because it is too large Load Diff

@ -20,7 +20,7 @@ The optimizer is used to calculate and update the gradients.
"""
from .optimizer import Optimizer
from .momentum import Momentum
from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR
from .adam import Adam, PSAdam, AdamWeightDecay
from .lamb import Lamb
from .sgd import SGD
from .lars import LARS
@ -30,4 +30,4 @@ from .proximal_ada_grad import ProximalAdagrad
from .lazyadam import LazyAdam
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam',
'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']

File diff suppressed because it is too large Load Diff

@ -24,9 +24,9 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices", "Tensor",
@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "IndexedSlices", "Tensor",
"Tensor", "Bool")
def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment,
def _tensor_run_opt_with_sparse(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment,
ps_parameter):
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
success = True
@ -43,9 +43,9 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power,
return success
@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool")
def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment, ps_parameter):
def _tensor_run_opt(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment, ps_parameter):
"""Apply ftrl optimizer to the weight parameter."""
success = True
if ps_parameter:
@ -83,7 +83,7 @@ def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2,
return success
def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None):
"""Check param."""
validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
@ -99,9 +99,6 @@ def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0,
validator.check_value_type("use_locking", use_locking, [bool], prim_name)
validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)
class FTRL(Optimizer):
"""
@ -113,15 +110,34 @@ class FTRL(Optimizer):
<https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on all of the parameters.
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be Parameter.
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Using different learning rate by separating parameters is currently not supported.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (float): The learning rate value, should be positive. Default: 0.001.
learning_rate (float): The learning rate value, should be zero or positive, dynamic learning rate is currently
not supported. Default: 0.001.
lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less
than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5.
l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
@ -139,23 +155,36 @@ class FTRL(Optimizer):
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.FTRL(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
>>> {'params': no_conv_params},
>>> {'order_params': net.trainable_params()}]
>>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use weight decay of 0.01.
>>> # The no_conv_params's parameters will use default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.FTRL(net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(FTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
super(FTRL, self).__init__(learning_rate, params, weight_decay, loss_scale=loss_scale)
if self.dynamic_lr or self.is_group_lr:
raise ValueError('Dynamic learning rate or group learning rate is currently not supported.')
_check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
self.linear = self.parameters.clone(prefix="linear", init='zeros')
self.l1 = l1
self.l2 = l2
self.lr_power = lr_power
self.weight_decay = weight_decay
self.decay_tf = tuple((lambda: True)() for x in self.parameters)
if not self.is_group:
self.decay_flags = tuple((lambda: True)() for x in self.parameters)
self.hyper_map = C.HyperMap()
self.opt = P.ApplyFtrl(use_locking=use_locking)
self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
@ -164,12 +193,11 @@ class FTRL(Optimizer):
params = self.parameters
moments = self.moments
linear = self.linear
lr = self.learning_rate
if self.weight_decay > 0.0:
grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
lr = self.get_lr()
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self.l1, self.l2, self.lr_power, lr),
linear, grads, params, moments, self.ps_parameters)
return success
@ -180,7 +208,7 @@ class PSFTRL(Optimizer):
super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
_check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
self.linear = self.parameters.clone(prefix="linear", init='zeros')
self.l1 = l1

File diff suppressed because it is too large Load Diff

@ -38,14 +38,14 @@ def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_f
return gradient
def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name:
raise TypeError("LARS can not be used with ", optimizer.cls_name)
validator.check_value_type("epsilon", epsilon, [float], prim_name)
validator.check_value_type("coefficient", coefficient, [float], prim_name)
validator.check_value_type("use_clip", use_clip, [bool], prim_name)
class LARS(Optimizer):
"""
Implements the LARS algorithm with LARSUpdate Operator.
@ -81,45 +81,71 @@ class LARS(Optimizer):
super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
_check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
self.opt = optimizer
self.parameters = optimizer.parameters
self.use_clip = use_clip
self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
self.is_group = optimizer.is_group
self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
self.decay_flags = optimizer.decay_flags
self.reciprocal_scale = optimizer.reciprocal_scale
self.hyper_map = C.HyperMap()
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
self.cast = P.Cast()
self.parameters = optimizer.parameters
if use_clip is True:
self.learning_rate = optimizer.learning_rate
if use_clip:
self.is_group_lr = optimizer.is_group_lr
self.dynamic_lr = optimizer.dynamic_lr
self.gather = optimizer.gather
self.assignadd = optimizer.assignadd
self.origin_learning_rate = optimizer.learning_rate
self.global_step = optimizer.global_step
else:
self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
self.reciprocal_scale = optimizer.reciprocal_scale
optimizer.reciprocal_scale = 1.0
self.is_group = optimizer.is_group
if self.is_group_lr and self.dynamic_lr:
raise ValueError('Grouped dynamic learning rate is currently not supported for the inputs optimizer ' \
'of lars.')
if self.is_group:
self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay))
optimizer.weight_decay = tuple(map(lambda x: 0.0, optimizer.weight_decay))
else:
self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
optimizer.weight_decay = 0.0
optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags))
optimizer.reciprocal_scale = 1.0
optimizer.exec_weight_decay = False
optimizer.weight_decay = 0.0
self.decay_flags = optimizer.decay_flags
self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
self.hyper_map = C.HyperMap()
def _get_lr(self):
"""Get the learning rate of current step."""
lr = self.origin_learning_rate
if self.dynamic_lr:
if self.is_group_lr:
lr = ()
for learning_rate in self.origin_learning_rate:
current_dynamic_lr = learning_rate(self.global_step)
lr += (current_dynamic_lr,)
else:
lr = self.origin_learning_rate(self.global_step)
return lr
def construct(self, gradients):
params = self.parameters
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, 0)
F.control_depend(lr, self.assignadd(self.global_step, 1))
if self.use_clip:
lr = self._get_lr()
else:
lr = self.learning_rate
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
if self.is_group:
grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
gradients, params, self.decay_flags, self.lars_flag)
if self.is_group_lr:
gradients = self.hyper_map(F.partial(_lars_opt, self.lars), lr, self.weight_decay,
gradients, params, self.decay_flags, self.lars_flag)
else:
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
gradients, params, self.decay_flags, self.lars_flag)
else:
grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
gradients, params, self.decay_flags, self.lars_flag)
success = self.opt(grad_t)
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
gradients, params, self.decay_flags, self.lars_flag)
success = self.opt(gradients)
return success

@ -84,12 +84,11 @@ class LazyAdam(Optimizer):
:math:`\epsilon` represents `eps`.
Note:
The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
The sparse behavior, to be notice, is not equivalent to the
@ -113,13 +112,14 @@ class LazyAdam(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 1e-3.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
0.9.
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
@ -153,9 +153,9 @@ class LazyAdam(Optimizer):
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
>>> {'params': no_conv_params, 'lr': 0.01},
>>> {'order_params': net.trainable_params()}]
>>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01.
>>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0.
>>> opt = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()

@ -47,12 +47,9 @@ class Momentum(Optimizer):
Refer to the paper on the importance of initialization and momentum in deep learning for more details.
Note:
The Momentum optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
@ -73,14 +70,13 @@ class Momentum(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a
Tensor but the dims of the Tensor is 0, use fixed learning
rate. Other cases are not supported. It should be equal to
or greater than 0.0.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
momentum (float): Hyperparameter of type float, means momentum for the moving average.
It should be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.

File diff suppressed because it is too large Load Diff

@ -32,7 +32,7 @@ def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient
@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
def _tensor_run_opt(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
"""Apply proximal_ada_grad optimizer to the weight parameter."""
success = True
success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient))
@ -59,15 +59,42 @@ class ProximalAdagrad(Optimizer):
<http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be Parameter.
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (float): The learning rate value, must be greater than or equal to zero. Default: 0.001.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 0.001.
l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If True use locks for update operation. Default: False.
@ -83,21 +110,31 @@ class ProximalAdagrad(Optimizer):
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.ProximalAdagrad(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
>>> {'params': no_conv_params, 'lr': 0.01},
>>> {'order_params': net.trainable_params()}]
>>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.ProximalAdagrad(net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0,
use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(accum, l1, l2, use_locking, self.cls_name)
self.accum = self.parameters.clone(prefix="accum", init=accum)
self.l1 = Tensor(l1, mstype.float32)
self.l2 = Tensor(l2, mstype.float32)
self.weight_decay = weight_decay
self.hyper_map = C.HyperMap()
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking)
@ -107,7 +144,11 @@ class ProximalAdagrad(Optimizer):
accum = self.accum
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
lr = self.learning_rate
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2),
grads, params, accum)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
grads, params, accum)
else:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr),
grads, params, accum)
return success

@ -44,12 +44,9 @@ class RMSProp(Optimizer):
Implements Root Mean Squared Propagation (RMSProp) algorithm.
Note:
The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
@ -109,13 +106,14 @@ class RMSProp(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 0.1.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 0.1.
decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
greater than 0. Default: 0.0.

@ -40,14 +40,11 @@ class SGD(Optimizer):
momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
Note:
The SGD optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
@ -66,14 +63,14 @@ class SGD(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. It should be equal to or
greater than 0. Default: 0.1.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 0.1.
momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0.
dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.

@ -14,9 +14,10 @@
# ============================================================================
"""Learning scheduler."""
from math import ceil
import numpy as np
import mindspore.nn.learning_rate_schedule as lr_schedules
def square_root_schedule(lr, update_num, decay_start_step,
warmup_steps=2000,
@ -105,3 +106,35 @@ def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
return lrs
class BertLearningRate(lr_schedules.LearningRateSchedule):
"""
Implements of warmup-polydecay learning rate scheduler.
Args:
learning_rate (float): The initial value of learning rate.
end_learning_rate (float): The end value of learning rate.
warmup_steps (int): The warm up steps of learning rate.
decay_steps (int): A value used to calculate decayed learning rate.
power (float): A value used to calculate decayed learning rate.
Returns:
Tensor. The learning rate value for the current step.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr

@ -37,7 +37,7 @@ from src.transformer.infer_mass import infer
from src.utils import LossCallBack
from src.utils import one_weight, zero_weight, weight_variable
from src.utils import square_root_schedule
from src.utils.lr_scheduler import polynomial_decay_scheduler
from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate
parser = argparse.ArgumentParser(description='MASS train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
@ -178,10 +178,16 @@ def _build_training_pipeline(config: TransformerConfig,
if config.optimizer.lower() == "adam":
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)
elif config.optimizer.lower() == "lamb":
optimizer = Lamb(net_with_loss.trainable_params(), decay_steps=12000,
start_learning_rate=config.lr, end_learning_rate=config.min_lr,
power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01,
eps=1e-6)
lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
power=10.0, warmup_steps=config.warmup_steps)
decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
net_with_loss.trainable_params()))
other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
net_with_loss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
optimizer = Lamb(group_params, lr, eps=1e-6)
elif config.optimizer.lower() == "momentum":
optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9)
else:

@ -147,7 +147,7 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation):
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
Parameters for optimizer:
AdamWeightDecayDynamicLR:
AdamWeightDecay:
decay_steps steps of the learning rate decay: N
learning_rate value of learning rate: Q
end_learning_rate value of end learning rate: Q, must be positive

@ -23,12 +23,12 @@ from src.bert_for_finetune import BertFinetuneCell, BertCLS
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_classification_dataset
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
from src.utils import make_directory, LossCallBack, LoadNewestCkpt
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
@ -42,27 +42,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
decay_steps=steps_per_epoch * epoch_num,
learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
if optimizer_cfg.optimizer == 'AdamWeightDecay':
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
elif optimizer_cfg.optimizer == 'Lamb':
optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_filter=optimizer_cfg.Lamb.decay_filter)
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
elif optimizer_cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
momentum=optimizer_cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)

@ -23,13 +23,13 @@ import argparse
from src.bert_for_finetune import BertFinetuneCell, BertNER
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_ner_dataset
from src.utils import make_directory, LossCallBack, LoadNewestCkpt
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
@ -44,27 +44,30 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
decay_steps=steps_per_epoch * epoch_num,
learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
if optimizer_cfg.optimizer == 'AdamWeightDecay':
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
elif optimizer_cfg.optimizer == 'Lamb':
optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_filter=optimizer_cfg.Lamb.decay_filter)
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
elif optimizer_cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
momentum=optimizer_cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)

@ -28,12 +28,12 @@ from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack
from src.utils import LossCallBack, BertLearningRate
_current_dir = os.path.dirname(os.path.realpath(__file__))
@ -109,24 +109,35 @@ def run_pretrain():
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
if cfg.optimizer == 'Lamb':
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count,
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
eps=cfg.Lamb.eps)
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=cfg.Lamb.warmup_steps,
decay_steps=ds.get_dataset_size() * new_repeat_count,
power=cfg.Lamb.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params}]
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
decay_steps=ds.get_dataset_size() * new_repeat_count,
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=cfg.AdamWeightDecayDynamicLR.power,
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=cfg.AdamWeightDecayDynamicLR.eps,
warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
elif cfg.optimizer == 'AdamWeightDecay':
lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=cfg.AdamWeightDecay.warmup_steps,
decay_steps=ds.get_dataset_size() * new_repeat_count,
power=cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
format(cfg.optimizer))
callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
if args_opt.enable_save_ckpt == "true":

@ -25,12 +25,12 @@ from src.dataset import create_squad_dataset
from src import tokenization
from src.create_squad_data import read_squad_examples, convert_examples_to_features
from src.run_squad import write_predictions
from src.utils import make_directory, LossCallBack, LoadNewestCkpt
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
@ -44,27 +44,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
decay_steps=steps_per_epoch * epoch_num,
learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
if optimizer_cfg.optimizer == 'AdamWeightDecay':
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
elif optimizer_cfg.optimizer == 'Lamb':
optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_filter=optimizer_cfg.Lamb.decay_filter)
lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
elif optimizer_cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
momentum=optimizer_cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)

@ -24,20 +24,22 @@ cfg = edict({
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'AdamWeightDecay': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 5.0,
'weight_decay': 1e-5,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
'warmup_steps': 10000,
}),
'Lamb': edict({
'start_learning_rate': 3e-5,
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
}),
'Momentum': edict({

@ -23,19 +23,20 @@ from .bert_model import BertConfig
optimizer_cfg = edict({
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'AdamWeightDecay': edict({
'learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 1e-5,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
}),
'Lamb': edict({
'start_learning_rate': 2e-5,
'learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 0.01,
'decay_filter': lambda x: False,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
}),
'Momentum': edict({
'learning_rate': 2e-5,

@ -23,6 +23,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
class CrossEntropyCalculation(nn.Cell):
@ -123,3 +124,25 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre
max_num = int(num)
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
return load_finetune_checkpoint_path
class BertLearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for Bert network.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr

@ -30,7 +30,7 @@ verification_set = [
'block': {
'model': network,
'loss': SquaredLoss(),
'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01),
'opt': Lamb(network.trainable_params(), 0.02, weight_decay=0.01),
'num_epochs': num_epochs,
'loss_upper_bound': 0.3,
},

@ -31,7 +31,7 @@ Example:
'block': {
'model': network,
'loss': SquaredLoss(),
'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01),
'opt': Lamb(network.trainable_params(), lr=0.02, weight_decay=0.01),
'num_epochs': num_epochs,
'loss_upper_bound': 0.3,
},

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save