@@ -42,51 +42,51 @@ class RMSProp(Optimizer):
"""
Implements Root Mean Squared Propagation (RMSProp) algorithm.

Note:
    When parameter groups are separated, the weight decay set in each group is applied to that group's
    parameters if the weight decay is positive. When parameter groups are not separated, the `weight_decay`
    given to the API is applied to the parameters whose names do not contain 'beta' or 'gamma', provided
    `weight_decay` is positive.

    To improve the performance of parameter groups, a customized order of parameters is supported.
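
For example, the 'beta'/'gamma' rule in the Note can be illustrated with a short, runnable sketch.
This is an illustrative sketch only; the parameter names below are hypothetical and not part of this API:

.. code-block:: python

    # Hypothetical parameter names; weight decay skips 'beta'/'gamma' entries.
    param_names = ['conv1.weight', 'bn1.gamma', 'bn1.beta', 'fc.weight']
    decayed = [n for n in param_names if 'beta' not in n and 'gamma' not in n]
    no_decay = [n for n in param_names if 'beta' in n or 'gamma' in n]
    print(decayed)   # ['conv1.weight', 'fc.weight'] -- weight_decay applied
    print(no_decay)  # ['bn1.gamma', 'bn1.beta']     -- weight_decay skipped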

Update `params` according to the RMSProp algorithm.

The equation is as follows:

.. math::
    s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2

.. math::
    m_{t} = \\beta m_{t-1} + \\frac{\\eta}{\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)

.. math::
    w = w - m_{t}

The first equation calculates a moving average of the squared gradient for each weight;
the gradient is then divided by :math:`\\sqrt{s_{t} + \\epsilon}`.
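
Concretely, the non-centered update can be sketched in a few lines of NumPy. This is an illustrative
sketch of the three equations above, not the operator's actual implementation; the state names
`square_avg` and `moment` and the default hyperparameter values are assumptions:

.. code-block:: python

    import numpy as np

    def rmsprop_step(w, grad, square_avg, moment,
                     learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10):
        # s_t = rho * s_{t-1} + (1 - rho) * (grad)^2
        square_avg = decay * square_avg + (1.0 - decay) * grad ** 2
        # m_t = beta * m_{t-1} + eta / sqrt(s_t + eps) * grad
        moment = momentum * moment + learning_rate * grad / np.sqrt(square_avg + epsilon)
        # w = w - m_t
        return w - moment, square_avg, moment

    # One step on a toy weight vector; optimizer state starts at zero.
    w, s, m = np.ones(3), np.zeros(3), np.zeros(3)
    w, s, m = rmsprop_step(w, np.array([0.5, -0.1, 0.2]), s, m)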

If `centered` is True:

.. math::
    g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)

.. math::
    s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2

.. math::
    m_{t} = \\beta m_{t-1} + \\frac{\\eta}{\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)

.. math::
    w = w - m_{t}
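
The centered variant can be sketched the same way, under the same naming assumptions as the
non-centered sketch above:

.. code-block:: python

    import numpy as np

    def centered_rmsprop_step(w, grad, mean_grad, square_avg, moment,
                              learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10):
        # g_t = rho * g_{t-1} + (1 - rho) * grad
        mean_grad = decay * mean_grad + (1.0 - decay) * grad
        # s_t = rho * s_{t-1} + (1 - rho) * (grad)^2
        square_avg = decay * square_avg + (1.0 - decay) * grad ** 2
        # m_t = beta * m_{t-1} + eta / sqrt(s_t - g_t^2 + eps) * grad
        denom = np.sqrt(square_avg - mean_grad ** 2 + epsilon)
        moment = momentum * moment + learning_rate * grad / denom
        # w = w - m_t
        return w - moment, mean_grad, square_avg, moment

Subtracting :math:`g_{t}^2` turns :math:`s_{t}` into an estimate of the gradient's variance rather
than its raw second moment, which makes the step size less sensitive to a large mean gradient.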

where :math:`w` represents `params`, the parameters to be updated.
:math:`g_{t}` is the mean of the gradients and :math:`g_{t-1}` is its value at the previous step.
:math:`s_{t}` is the mean of the squared gradients and :math:`s_{t-1}` is its value at the previous step.
:math:`m_{t}` is the moment, i.e. the delta applied to `w`, and :math:`m_{t-1}` is its value at the previous step.
:math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term and represents `momentum`.
:math:`\\epsilon` is a smoothing term to avoid division by zero and represents `epsilon`.
:math:`\\eta` is the learning rate and represents `learning_rate`. :math:`\\nabla Q_{i}(w)` is the gradient
and represents `gradients`.

Args:
    params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` which will be updated,