|
|
@ -90,22 +90,22 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
|
|
|
|
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
|
|
|
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
|
|
|
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
|
|
|
|
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
|
|
|
|
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
|
|
|
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
|
|
|
beta2_power, beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
|
|
|
|
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter):
|
|
|
|
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
|
|
|
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
|
|
|
success = True
|
|
|
|
success = True
|
|
|
|
indices = gradient.indices
|
|
|
|
indices = gradient.indices
|
|
|
|
values = gradient.values
|
|
|
|
values = gradient.values
|
|
|
|
if ps_parameter:
|
|
|
|
if ps_parameter:
|
|
|
|
op_shape = P.Shape()
|
|
|
|
op_shape = P.Shape()
|
|
|
|
shapes = (op_shape(params), op_shape(m), op_shape(v),
|
|
|
|
shapes = (op_shape(param), op_shape(m), op_shape(v),
|
|
|
|
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
|
|
|
|
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
|
|
|
|
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
|
|
|
|
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
|
|
|
|
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
|
|
|
|
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
|
|
|
|
eps, values, indices), shapes), params))
|
|
|
|
eps, values, indices), shapes), param))
|
|
|
|
return success
|
|
|
|
return success
|
|
|
|
|
|
|
|
|
|
|
|
if not target:
|
|
|
|
if not target:
|
|
|
|
success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
|
|
|
|
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
|
|
|
|
eps, values, indices))
|
|
|
|
eps, values, indices))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
op_mul = P.Mul()
|
|
|
|
op_mul = P.Mul()
|
|
|
@ -145,12 +145,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
|
|
|
|
|
|
|
|
|
|
|
|
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
|
|
|
|
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
|
|
|
|
|
|
|
|
|
|
|
|
next_param = params - lr_t * param_update
|
|
|
|
next_param = param - lr_t * param_update
|
|
|
|
|
|
|
|
|
|
|
|
F.control_depend(assign_m, next_m)
|
|
|
|
F.control_depend(assign_m, next_m)
|
|
|
|
F.control_depend(assign_v, next_v)
|
|
|
|
F.control_depend(assign_v, next_v)
|
|
|
|
|
|
|
|
|
|
|
|
success = F.depend(success, F.assign(params, next_param))
|
|
|
|
success = F.depend(success, F.assign(param, next_param))
|
|
|
|
success = F.depend(success, F.assign(m, next_m))
|
|
|
|
success = F.depend(success, F.assign(m, next_m))
|
|
|
|
success = F.depend(success, F.assign(v, next_v))
|
|
|
|
success = F.depend(success, F.assign(v, next_v))
|
|
|
|
|
|
|
|
|
|
|
@ -160,18 +160,29 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
|
|
|
|
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
|
|
|
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
|
|
|
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
|
|
|
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
|
|
|
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
|
|
|
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
|
|
|
beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
|
|
|
|
beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter):
|
|
|
|
"""Apply adam optimizer to the weight parameter using Tensor."""
|
|
|
|
"""Apply adam optimizer to the weight parameter using Tensor."""
|
|
|
|
success = True
|
|
|
|
success = True
|
|
|
|
if ps_parameter:
|
|
|
|
if ps_parameter:
|
|
|
|
op_shape = P.Shape()
|
|
|
|
op_shape = P.Shape()
|
|
|
|
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
|
|
|
|
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
|
|
|
|
(op_shape(params), op_shape(moment1), op_shape(moment2))), params))
|
|
|
|
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
|
|
|
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
|
|
|
eps, gradient))
|
|
|
|
eps, gradient))
|
|
|
|
return success
|
|
|
|
return success
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
|
|
|
|
|
|
|
"Tensor", "Tensor")
|
|
|
|
|
|
|
|
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
|
|
|
|
|
|
|
|
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
|
|
|
|
|
|
|
|
success = True
|
|
|
|
|
|
|
|
delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
|
|
|
|
|
|
|
|
success = F.depend(success, F.assign_add(param, delat_param))
|
|
|
|
|
|
|
|
return success
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_param_value(beta1, beta2, eps, prim_name):
|
|
|
|
def _check_param_value(beta1, beta2, eps, prim_name):
|
|
|
|
"""Check the type of inputs."""
|
|
|
|
"""Check the type of inputs."""
|
|
|
|
validator.check_value_type("beta1", beta1, [float], prim_name)
|
|
|
|
validator.check_value_type("beta1", beta1, [float], prim_name)
|
|
|
@ -443,3 +454,146 @@ class AdamWeightDecay(Optimizer):
|
|
|
|
if self.use_parallel:
|
|
|
|
if self.use_parallel:
|
|
|
|
self.broadcast_params(optim_result)
|
|
|
|
self.broadcast_params(optim_result)
|
|
|
|
return optim_result
|
|
|
|
return optim_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdamOffload(Optimizer):
|
|
|
|
|
|
|
|
r"""
|
|
|
|
|
|
|
|
Updates gradients by the Adaptive Moment Estimation (Adam) algorithm. This optimizer will offload Adam optimizer to
|
|
|
|
|
|
|
|
host CPU and keep parameters being updated on the device, to minimize the memory cost. Although that would bring
|
|
|
|
|
|
|
|
about an increase of performance overhead, the optimizer could be used to run a larger model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
The updating formulas are as follows,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
|
|
|
\begin{array}{ll} \\
|
|
|
|
|
|
|
|
m = \beta_1 * m + (1 - \beta_1) * g \\
|
|
|
|
|
|
|
|
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
|
|
|
|
|
|
|
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
|
|
|
|
|
|
|
|
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
|
|
|
|
|
|
|
|
\end{array}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
|
|
|
|
|
|
|
|
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
|
|
|
|
|
|
|
|
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
|
|
|
|
|
|
|
|
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
|
|
|
|
|
|
|
|
:math:`\epsilon` represents `eps`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
|
|
|
This optimizer only supports `GRAPH_MODE` currently.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
|
|
|
|
|
|
|
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
|
|
|
|
|
|
|
|
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
To improve parameter groups performance, the customized order of parameters is supported.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
|
|
|
|
|
|
|
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
|
|
|
|
|
|
|
"lr", "weight_decay" and "order_params" are the keys can be parsed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- params: Required. The value must be a list of `Parameter`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
|
|
|
|
|
|
|
|
If not, the `learning_rate` in the API will be used.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
|
|
|
|
|
|
|
|
will be used. If not, the `weight_decay` in the API will be used.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
|
|
|
|
|
|
|
|
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
|
|
|
|
|
|
|
|
which in the 'order_params' must be in one of group parameters.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
|
|
|
|
|
|
|
|
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
|
|
|
|
|
|
|
|
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
|
|
|
|
|
|
|
|
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
|
|
|
|
|
|
|
|
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
|
|
|
|
|
|
|
|
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
|
|
|
|
|
|
|
|
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
|
|
|
|
|
|
|
|
Default: 1e-3.
|
|
|
|
|
|
|
|
beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
|
|
|
|
|
|
|
Default: 0.9.
|
|
|
|
|
|
|
|
beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
|
|
|
|
|
|
|
Default: 0.999.
|
|
|
|
|
|
|
|
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
|
|
|
|
|
|
|
|
1e-8.
|
|
|
|
|
|
|
|
use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
|
|
|
|
|
|
|
|
If true, updates of the var, m, and v tensors will be protected by a lock.
|
|
|
|
|
|
|
|
If false, the result is unpredictable. Default: False.
|
|
|
|
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
|
|
|
|
If true, update the gradients using NAG.
|
|
|
|
|
|
|
|
If false, update the gradients without using NAG. Default: False.
|
|
|
|
|
|
|
|
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
|
|
|
|
|
|
|
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
|
|
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
|
|
|
Tensor[bool], the value is True.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
|
|
|
>>> #1) All parameters use the same learning rate and weight decay
|
|
|
|
|
|
|
|
>>> optim = nn.AdamOffload(params=net.trainable_params())
|
|
|
|
|
|
|
|
>>>
|
|
|
|
|
|
|
|
>>> #2) Use parameter groups and set different values
|
|
|
|
|
|
|
|
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
|
|
|
|
|
|
|
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
|
|
|
|
|
|
|
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
|
|
|
|
|
|
|
|
>>> {'params': no_conv_params, 'lr': 0.01},
|
|
|
|
|
|
|
|
>>> {'order_params': net.trainable_params()}]
|
|
|
|
|
|
|
|
>>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0)
|
|
|
|
|
|
|
|
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
|
|
|
|
|
|
|
|
>>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0.
|
|
|
|
|
|
|
|
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
|
|
|
|
|
|
|
|
>>>
|
|
|
|
|
|
|
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
|
|
|
|
|
|
|
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
|
|
|
|
|
|
|
use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
|
|
|
|
|
|
|
|
super(AdamOffload, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
|
|
|
|
|
|
|
_check_param_value(beta1, beta2, eps, self.cls_name)
|
|
|
|
|
|
|
|
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
|
|
|
|
|
|
|
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.beta1 = Tensor(beta1, mstype.float32)
|
|
|
|
|
|
|
|
self.beta2 = Tensor(beta2, mstype.float32)
|
|
|
|
|
|
|
|
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
|
|
|
|
|
|
|
|
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
|
|
|
|
|
|
|
|
self.eps = Tensor(eps, mstype.float32)
|
|
|
|
|
|
|
|
self.use_nesterov = use_nesterov
|
|
|
|
|
|
|
|
self.use_locking = use_locking
|
|
|
|
|
|
|
|
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
|
|
|
|
|
|
|
|
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.hyper_map = C.HyperMap()
|
|
|
|
|
|
|
|
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
|
|
|
|
|
|
|
|
self.opt.add_prim_attr("primitive_target", "CPU")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct(self, gradients):
|
|
|
|
|
|
|
|
params = self.parameters
|
|
|
|
|
|
|
|
moment1 = self.moment1
|
|
|
|
|
|
|
|
moment2 = self.moment2
|
|
|
|
|
|
|
|
gradients = self.decay_weight(gradients)
|
|
|
|
|
|
|
|
gradients = self.scale_grad(gradients)
|
|
|
|
|
|
|
|
lr = self.get_lr()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
beta1_power = self.beta1_power * self.beta1
|
|
|
|
|
|
|
|
self.beta1_power = beta1_power
|
|
|
|
|
|
|
|
beta2_power = self.beta2_power * self.beta2
|
|
|
|
|
|
|
|
self.beta2_power = beta2_power
|
|
|
|
|
|
|
|
if self.is_group_lr:
|
|
|
|
|
|
|
|
success = self.map_(F.partial(_adam_opt, self.opt,
|
|
|
|
|
|
|
|
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
|
|
|
|
|
|
|
|
lr, gradients, params, moment1, moment2)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
success = self.map_(F.partial(_adam_opt, self.opt,
|
|
|
|
|
|
|
|
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
|
|
|
|
|
|
|
|
gradients, params, moment1, moment2)
|
|
|
|
|
|
|
|
return success
|
|
|
|