@@ -19,6 +19,7 @@ import numpy as np
 import mindspore
 from mindspore.ops import functional as F, composite as C, operations as P
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.container import CellList
 from mindspore.common.parameter import Parameter, ParameterTuple
@@ -43,12 +44,16 @@ class Optimizer(Cell):
     This class defines the API to add Ops to train a model. Never use
     this class directly, but instead instantiate one of its subclasses.
 
-    Different parameter groups can set different `learning_rate` and `weight_decay`.
+    Different parameter groups can set different `learning_rate`, `weight_decay` and `grad_centralization`.
 
     When separating parameter groups, the weight decay in each group will be applied on the parameters if the
     weight_decay is positive. For most optimizer, when not separating parameters, the `weight_decay` in the API will
     be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
 
+    When separating parameter groups, if you want to centralize the gradient, set `grad_centralization` to True,
+    but the gradient centralization can only be applied to the parameters of the convolution layer. If the
+    parameters of the non convolution layer are set to True, an error will be reported. Default: False.
+
     To improve parameter groups performance, the customized order of parameters can be supported.
 
     Args:
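# Illustrative usage of the grouped-parameter API described in the docstring above;
# `Net` stands in for any user-defined network that contains convolution layers, and
# Momentum is used only as an example Optimizer subclass.
from mindspore import nn

net = Net()
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
                {'params': no_conv_params, 'lr': 0.01},
                {'order_params': net.trainable_params()}]
optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)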
@@ -75,6 +80,9 @@ class Optimizer(Cell):
               the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
               in the value of 'order_params' must be in one of group parameters.
 
+            - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+              If not, the `grad_centralization` in the base class will be used.
+
         weight_decay (float): A floating point value for the weight decay. It must be equal to or greater than 0.
             If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
@@ -106,6 +114,7 @@ class Optimizer(Cell):
         self.loss_scale = loss_scale
         weight_decay = self._preprocess_weight_decay(weight_decay)
+        self.grad_centralization = False
 
         self._unique = True
         self._target = context.get_context("device_target")
@@ -121,7 +130,8 @@ class Optimizer(Cell):
             self.group_params = []
             self.group_lr = []
             self.group_weight_decay = []
-            self._init_group_params(parameters, learning_rate, weight_decay)
+            self.group_grad_centralization = []
+            self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)
 
         # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
         if self.dynamic_lr:
@@ -129,12 +139,10 @@ class Optimizer(Cell):
             self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
 
         if self.is_group_lr:
-            if self.dynamic_lr:
-                self.learning_rate = CellList(self.group_lr)
-            else:
-                self.learning_rate = ParameterTuple(self.group_lr)
+            self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
         else:
             self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')
 
         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
@@ -142,6 +150,7 @@ class Optimizer(Cell):
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
             self.exec_weight_decay = any(self.decay_flags)
+            self.grad_centralization_flags = tuple(self.group_grad_centralization)
         else:
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
@@ -163,6 +172,10 @@ class Optimizer(Cell):
         self.global_step_increase_tensor = Tensor(1, mstype.int32)
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
+        self._use_parallel_optimizer()
+
+    def _use_parallel_optimizer(self):
+        """Indicates whether to use automatic parallelism."""
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
             if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                 self.use_parallel = True
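# Hedged sketch of the configuration the check above looks for: data-parallel training on
# Ascend with the parallel-optimizer switch enabled. Shown only to make the condition concrete.
from mindspore import context

context.set_context(device_target="Ascend")
context.set_auto_parallel_context(parallel_mode="data_parallel", enable_parallel_optimizer=True)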
@@ -187,7 +200,6 @@ class Optimizer(Cell):
             self.param_names = []
             for param in self.parameters:
                 self.param_names.append(param.name)
 
         else:
             self.optim_filter = (True,) * self.param_length
@@ -239,6 +251,25 @@ class Optimizer(Cell):
         return gradients
 
+    def gradients_centralization(self, gradients):
+        """
+        Gradients centralization.
+
+        A method for optimizing convolutional layer parameters to improve the training speed of a deep learning neural
+        network model.
+
+        Args:
+            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
+                `self.parameters`.
+
+        Returns:
+            tuple[Tensor], The gradients after gradients centralization.
+        """
+        if self.is_group:
+            gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients)
+
+        return gradients
+
     def scale_grad(self, gradients):
         """
         Loss scale for mixed precision.
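# Hedged numpy illustration of the idea behind gradients_centralization (gradient
# centralization, Yong et al. 2020): for a conv weight gradient of shape
# (out_channels, in_channels, kh, kw), subtract from each output channel's slice the mean
# of that slice. The exact axis handling of the inner Centralization op may differ; this
# only conveys the effect.
import numpy as np

grad = np.random.randn(8, 4, 3, 3).astype(np.float32)
centralized = grad - grad.mean(axis=(1, 2, 3), keepdims=True)
assert np.allclose(centralized.reshape(8, -1).mean(axis=1), 0.0, atol=1e-6)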
@@ -273,6 +304,11 @@ class Optimizer(Cell):
             return weight_decay
         raise TypeError("Weight decay should be int or float.")
 
+    def _preprocess_grad_centralization(self, grad_centralization):
+        if not isinstance(grad_centralization, bool):
+            raise TypeError("The gradient centralization should be bool.")
+        return grad_centralization
+
     def _preprocess_single_lr(self, learning_rate):
         """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
         if isinstance(learning_rate, (float, int)):
@ -315,7 +351,7 @@ class Optimizer(Cell):
|
|
|
|
|
|
|
|
|
|
def _check_group_params(self, parameters):
|
|
|
|
|
"""Check group params."""
|
|
|
|
|
parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
|
|
|
|
|
parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
|
|
|
|
|
for group_param in parameters:
|
|
|
|
|
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
|
|
|
|
|
if invalid_key:
|
|
|
|
@@ -365,8 +401,8 @@ class Optimizer(Cell):
         elif group_lr_length != tensor_lr_length:
             raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")
 
-    def _init_group_params(self, parameters, learning_rate, weight_decay):
-        """Initialize learning rate or weight decay in group params."""
+    def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
+        """Initialize learning rate, weight decay or grad centralization in group params."""
         self._parse_group_params(parameters, learning_rate)
         default_lr = self._build_single_lr(learning_rate, 'learning_rate')
@@ -391,8 +427,20 @@ class Optimizer(Cell):
             else:
                 weight_decay_ = weight_decay * self.loss_scale
 
+            if 'grad_centralization' in group_param.keys():
+                self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
+                for param in group_param['params']:
+                    validator.check_value_type("parameter", param, [Parameter], self.cls_name)
+                    if "conv" not in param.name and self.grad_centralization is True:
+                        raise ValueError("Grad centralization can be performed only on the conv layer. If the parameter"
+                                         " is not a convolution layer, this parameter cannot be set to True.")
+
+                grad_centralization_ = self.grad_centralization
+            else:
+                grad_centralization_ = grad_centralization
+
             for key in group_param.keys():
-                if key not in ('params', 'lr', 'weight_decay'):
+                if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
                     logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
 
             for param in group_param['params']:
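# Hedged sketch of the error path checked above: a group whose parameter names do not
# contain "conv" cannot request gradient centralization. `Net` and the Momentum optimizer
# are illustrative stand-ins.
net = Net()
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
bad_group = [{'params': no_conv_params, 'grad_centralization': True}]
# nn.Momentum(bad_group, learning_rate=0.1, momentum=0.9)  # raises ValueError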
@@ -403,13 +451,14 @@ class Optimizer(Cell):
                 params_store.append(param.name)
             self.group_lr.append(lr)
             self.group_weight_decay.append(weight_decay_)
+            self.group_grad_centralization.append(grad_centralization_)
 
         if self.is_group_params_ordered:
             self._order_and_adjust_group_params(ordered_parameters)
 
     def _order_and_adjust_group_params(self, ordered_parameters):
         """
-        Order group parameter, learning rate and weight decay in group params.
+        Order group parameter, learning rate, weight decay and grad centralization in group params.
         """
         params_length = len(self.group_params)
         if len(ordered_parameters) != len(self.group_params):
@@ -418,17 +467,21 @@ class Optimizer(Cell):
         ordered_params = [None] * params_length
         ordered_learning_rate = [None] * params_length
         ordered_weight_decay = [None] * params_length
+        ordered_grad_centralization = [None] * params_length
         params_name = [param.name for param in ordered_parameters]
 
-        for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
+        for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay,
+                                     self.group_grad_centralization):
             index = params_name.index(param.name)
             ordered_params[index] = param
             ordered_learning_rate[index] = lr
             ordered_weight_decay[index] = wd
+            ordered_grad_centralization[index] = gc
 
         self.group_params = ordered_params
         self.group_lr = ordered_learning_rate
         self.group_weight_decay = ordered_weight_decay
+        self.group_grad_centralization = ordered_grad_centralization
 
     def get_lr(self):
         """
@@ -535,8 +588,10 @@ class Optimizer(Cell):
 op_add = P.AddN()
 op_gather = P.Gather()
 op_mul = P.Mul()
+op_gc = inner.Centralization()
 
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
+_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")
 
 
 @_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
@@ -558,9 +613,18 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     return gradient
 
 
+@_apply_grad_centralization.register("Bool", "Tensor")
+def _tensor_apply_grad_centralization(if_apply, gradient):
+    """Get grad with grad_centralization."""
+    if if_apply:
+        return op_gc(gradient, -1)
+    return gradient
+
+
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
+_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")
 
 
 @_grad_scale.register("Number", "Tensor")
 def tensor_grad_scale(scale, grad):
     """Get grad with scale."""
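# Plain-Python analogy (illustrative only, not graph code): C.Map applies the overload
# registered above once per (flag, gradient) pair, so the group branch of
# gradients_centralization behaves like this loop, with `centralize` standing in for op_gc.
def apply_gc_like(flags, grads, centralize):
    return tuple(centralize(g) if f else g for f, g in zip(flags, grads))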
@@ -568,11 +632,13 @@ def tensor_grad_scale(scale, grad):
         return grad
     return op_mul(grad, F.cast(scale, F.dtype(grad)))
 
 
 @_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale_with_tensor(scale, grad):
     """Get grad with scale."""
     return op_mul(grad, F.cast(scale, F.dtype(grad)))
 
 
 @_grad_scale.register("Tensor", "RowTensor")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""