add grad centralization

pull/12828/head
Jiaqi 4 years ago
parent ca6586715b
commit af3a7a30f8

@ -67,6 +67,10 @@ class Adagrad(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
@ -98,12 +102,14 @@ class Adagrad(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.Adagrad(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -124,6 +130,7 @@ class Adagrad(Optimizer):
accum = self.accum
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
grads = self.gradients_centralization(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,

@ -235,6 +235,10 @@ class Adam(Optimizer):
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
which in the 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@ -275,12 +279,14 @@ class Adam(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -320,6 +326,7 @@ class Adam(Optimizer):
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
beta1_power = self.beta1_power * self.beta1

@ -97,7 +97,7 @@ class FTRL(Optimizer):
\end{cases}\\
\end{array}
:math:`m` represents `accum`, :math:`g` represents `grads`, :math:`t` represents updating step,
:math:`u` represents `linear`, :math:`p` represents `lr_power`, :math:`\alpha` represents `learning_rate`,
:math:`\omega` represents `params`.
@ -128,6 +128,10 @@ class FTRL(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (float): The learning rate value, must be zero or positive, dynamic learning rate is currently
not supported. Default: 0.001.
@ -157,12 +161,13 @@ class FTRL(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params},
... {'order_params': net.trainable_params()}]
>>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use default weight decay of 0.0 and grad centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -201,6 +206,7 @@ class FTRL(Optimizer):
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
grads = self._grad_sparse_indices_deduplicate(grads)
grads = self.gradients_centralization(grads)
lr = self.get_lr()
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,

@ -200,6 +200,10 @@ class Lamb(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@ -234,13 +238,14 @@ class Lamb(Optimizer):
... decay_steps=4, power = 0.5)
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': poly_decay_lr},
... {'order_params': net.trainable_params()}]
>>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default
>>> # weight decay of 0.0 and grad centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -268,6 +273,7 @@ class Lamb(Optimizer):
def construct(self, gradients):
lr = self.get_lr()
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
gradients = self.gradients_centralization(gradients)
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,

@ -154,6 +154,10 @@ class LazyAdam(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@ -195,12 +199,14 @@ class LazyAdam(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -237,6 +243,7 @@ class LazyAdam(Optimizer):
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
self.beta1_power = self.beta1_power * self.beta1

@ -83,6 +83,10 @@ class Momentum(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@ -117,12 +121,14 @@ class Momentum(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
>>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01 and
>>> # grad centralization of True.
>>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0
>>> # and grad centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -145,6 +151,7 @@ class Momentum(Optimizer):
moments = self.moments
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,

@ -19,6 +19,7 @@ import numpy as np
import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
@ -43,12 +44,16 @@ class Optimizer(Cell):
This class defines the API to add Ops to train a model. Never use
this class directly, but instead instantiate one of its subclasses.
Different parameter groups can set different `learning_rate`, `weight_decay` and `grad_centralization`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight_decay is positive. For most optimizers, when not separating parameters, the `weight_decay` in the API will
be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
When separating parameter groups, if you want to centralize the gradient, set `grad_centralization` to True,
but the gradient centralization can only be applied to the parameters of the convolution layer. If the
parameters of the non-convolution layer are set to True, an error will be reported. Default: False.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
@ -75,6 +80,9 @@ class Optimizer(Cell):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used.
weight_decay (float): A floating point value for the weight decay. It must be equal to or greater than 0.
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
@ -106,6 +114,7 @@ class Optimizer(Cell):
self.loss_scale = loss_scale
weight_decay = self._preprocess_weight_decay(weight_decay)
self.grad_centralization = False
self._unique = True
self._target = context.get_context("device_target")
@ -121,7 +130,8 @@ class Optimizer(Cell):
self.group_params = []
self.group_lr = []
self.group_weight_decay = []
self.group_grad_centralization = []
self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)
# The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
if self.dynamic_lr:
@ -129,12 +139,10 @@ class Optimizer(Cell):
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
if self.is_group_lr:
self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
else:
self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')
if self.is_group:
self.parameters = ParameterTuple(self.group_params)
self.weight_decay = tuple(self.group_weight_decay)
@ -142,6 +150,7 @@ class Optimizer(Cell):
decay_filter = lambda x: x > 0
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
self.exec_weight_decay = any(self.decay_flags)
self.grad_centralization_flags = tuple(self.group_grad_centralization)
else:
self.parameters = ParameterTuple(parameters)
self.weight_decay = weight_decay * loss_scale
@ -163,6 +172,10 @@ class Optimizer(Cell):
self.global_step_increase_tensor = Tensor(1, mstype.int32)
self.param_length = len(self.parameters)
self.map_ = C.Map()
self._use_parallel_optimizer()
def _use_parallel_optimizer(self):
"""Indicates whether to use automatic parallelism."""
if context.get_auto_parallel_context("enable_parallel_optimizer"):
if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
self.use_parallel = True
@ -187,7 +200,6 @@ class Optimizer(Cell):
self.param_names = []
for param in self.parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length
@ -239,6 +251,25 @@ class Optimizer(Cell):
return gradients
def gradients_centralization(self, gradients):
"""
Gradients centralization.
A method for optimizing convolutional layer parameters to improve the training speed of a deep learning neural
network model.
Args:
gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
`self.parameters`.
Returns:
tuple[Tensor], The gradients after gradients centralization.
"""
if self.is_group:
gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients)
return gradients
def scale_grad(self, gradients):
"""
Loss scale for mixed precision.
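The `gradients_centralization` method added above is a flag-gated map over the gradient tuple: groups whose `grad_centralization` flag is True are centralized, the rest pass through unchanged. A minimal NumPy sketch of the idea, assuming `inner.Centralization` subtracts the gradient's mean along the given axis (the exact semantics of that inner op are not spelled out in this diff); the real code runs inside the graph via `C.Map` and `MultitypeFuncGraph`:

import numpy as np

def centralize(grad, axis=-1):
    # Assumed behaviour of inner.Centralization: remove the mean of the gradient along `axis`.
    return grad - grad.mean(axis=axis, keepdims=True)

def gradients_centralization(flags, gradients):
    # Mirrors self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients):
    # apply centralization only where the per-group flag is True.
    return tuple(centralize(g) if flag else g for flag, g in zip(flags, gradients))

flags = (True, False)
grads = (np.arange(12, dtype=np.float32).reshape(3, 4), np.ones(5, dtype=np.float32))
out = gradients_centralization(flags, grads)
print(out[0].mean(axis=-1))  # rows of the centralized gradient now average to ~0
print(out[1])                # the non-centralized gradient is returned unchanged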
@ -273,6 +304,11 @@ class Optimizer(Cell):
return weight_decay
raise TypeError("Weight decay should be int or float.")
def _preprocess_grad_centralization(self, grad_centralization):
if not isinstance(grad_centralization, bool):
raise TypeError("The gradients centralization should be bool")
return grad_centralization
def _preprocess_single_lr(self, learning_rate):
"""Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
if isinstance(learning_rate, (float, int)):
@ -315,7 +351,7 @@ class Optimizer(Cell):
def _check_group_params(self, parameters):
"""Check group params."""
parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
for group_param in parameters:
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
if invalid_key:
@ -365,8 +401,8 @@ class Optimizer(Cell):
elif group_lr_length != tensor_lr_length:
raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")
def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
"""Initialize learning rate, weight decay or grad centralization in group params."""
self._parse_group_params(parameters, learning_rate)
default_lr = self._build_single_lr(learning_rate, 'learning_rate')
@ -391,8 +427,20 @@ class Optimizer(Cell):
else:
weight_decay_ = weight_decay * self.loss_scale
if 'grad_centralization' in group_param.keys():
self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
for param in group_param['params']:
validator.check_value_type("parameter", param, [Parameter], self.cls_name)
if "conv" not in param.name and self.grad_centralization is True:
raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
"is not a convolution layer, this parameter cannot be set to True.")
grad_centralization_ = self.grad_centralization
else:
grad_centralization_ = grad_centralization
for key in group_param.keys():
if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
for param in group_param['params']:
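The validation added in this hunk boils down to two checks: the group value must be a bool, and a group may only set 'grad_centralization': True when every parameter in it belongs to a convolution layer (identified by 'conv' in the parameter name); otherwise a ValueError is raised. A small standalone sketch of that rule, using hypothetical parameter-name strings instead of the MindSpore Parameter objects the real code inspects:

def check_group_grad_centralization(group, default=False):
    # Sketch of the rule enforced above, not the MindSpore API itself.
    gc = group.get('grad_centralization', default)
    if not isinstance(gc, bool):
        raise TypeError("The gradients centralization should be bool")
    if gc and any('conv' not in name for name in group['param_names']):
        raise ValueError("Grad centralization can be performed only on the conv layer.")
    return gc

conv_group = {'param_names': ['conv1.weight', 'conv2.weight'], 'grad_centralization': True}
dense_group = {'param_names': ['fc.weight', 'fc.bias']}  # falls back to the default (False)
print(check_group_grad_centralization(conv_group))   # True
print(check_group_grad_centralization(dense_group))  # False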
@ -403,13 +451,14 @@ class Optimizer(Cell):
params_store.append(param.name)
self.group_lr.append(lr)
self.group_weight_decay.append(weight_decay_)
self.group_grad_centralization.append(grad_centralization_)
if self.is_group_params_ordered:
self._order_and_adjust_group_params(ordered_parameters)
def _order_and_adjust_group_params(self, ordered_parameters):
"""
Order group parameter, learning rate, weight decay and grad centralization in group params.
"""
params_length = len(self.group_params)
if len(ordered_parameters) != len(self.group_params):
@ -418,17 +467,21 @@ class Optimizer(Cell):
ordered_params = [None] * params_length
ordered_learning_rate = [None] * params_length
ordered_weight_decay = [None] * params_length
ordered_grad_centralization = [None] * params_length
params_name = [param.name for param in ordered_parameters]
for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay,
self.group_grad_centralization):
index = params_name.index(param.name)
ordered_params[index] = param
ordered_learning_rate[index] = lr
ordered_weight_decay[index] = wd
ordered_grad_centralization[index] = gc
self.group_params = ordered_params
self.group_lr = ordered_learning_rate
self.group_weight_decay = ordered_weight_decay
self.group_grad_centralization = ordered_grad_centralization
def get_lr(self):
"""
@ -535,8 +588,10 @@ class Optimizer(Cell):
op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()
op_gc = inner.Centralization()
_apply_decay = C.MultitypeFuncGraph("apply_decay")
_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")
@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
@ -558,9 +613,18 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
return gradient
@_apply_grad_centralization.register("Bool", "Tensor")
def _tensor_apply_grad_centralization(if_apply, gradient):
"""Get grad with grad_centralization."""
if if_apply:
return op_gc(gradient, -1)
return gradient
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")
@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
"""Get grad with scale."""
@ -568,11 +632,13 @@ def tensor_grad_scale(scale, grad):
return grad
return op_mul(grad, F.cast(scale, F.dtype(grad)))
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
"""Get grad with scale."""
return op_mul(grad, F.cast(scale, F.dtype(grad)))
@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
"""Get grad with scale."""

@ -85,6 +85,10 @@ class ProximalAdagrad(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
@ -118,12 +122,14 @@ class ProximalAdagrad(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -148,6 +154,7 @@ class ProximalAdagrad(Optimizer):
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
grads = self._grad_sparse_indices_deduplicate(grads)
grads = self.gradients_centralization(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,

@ -105,6 +105,10 @@ class RMSProp(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@ -141,12 +145,14 @@ class RMSProp(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -182,6 +188,7 @@ class RMSProp(Optimizer):
params = self.parameters
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
if self.centered:
if self.is_group_lr:

@ -80,6 +80,10 @@ class SGD(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
If not, the `grad_centralization` in the base class will be used. This parameter only works on the
convolution layer.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@ -116,12 +120,14 @@ class SGD(Optimizer):
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
>>> # centralization of True.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
>>> # centralization of False.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@ -171,6 +177,7 @@ class SGD(Optimizer):
accum = self.accum
stat = self.stat
gradients = self.scale_grad(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
