@@ -17,6 +17,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
+from mindspore._checkparam import check_bool
 from .optimizer import Optimizer
 
 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -67,6 +68,7 @@ class Momentum(Optimizer):
         momentum (float): Hyperparameter of type float, means momentum for the moving average.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        use_nesterov (bool): Enable Nesterov momentum. Default: False.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -95,15 +97,16 @@ class Momentum(Optimizer):
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
-    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0):
+    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.params = self.parameters
+        self.use_nesterov = check_bool(use_nesterov)
         self.moments = self.params.clone(prefix="moments", init='zeros')
         self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
+        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
 
     def construct(self, gradients):
         params = self.params
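Note for reviewers: below is a minimal, framework-free sketch of the two update rules that `P.ApplyMomentum` switches between, assuming it follows the usual TensorFlow-style formulation of Nesterov momentum; the names `var`, `accum`, and `apply_momentum_step` are illustrative and not part of this patch.

```python
# Illustrative single-parameter momentum step; the real update runs inside
# the fused P.ApplyMomentum kernel on device.
def apply_momentum_step(var, accum, grad, lr, momentum, use_nesterov=False):
    """Return the updated (var, accum) pair for one scalar weight."""
    accum = accum * momentum + grad  # running accumulation of gradients
    if use_nesterov:
        # Nesterov variant: step along the gradient plus the momentum-scaled
        # accumulator, i.e. a "look-ahead" correction.
        var = var - lr * (grad + accum * momentum)
    else:
        # Plain momentum: step along the accumulator alone.
        var = var - lr * accum
    return var, accum
```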
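With this patch applied, Nesterov momentum would be requested at construction time; a hypothetical call mirroring the docstring example:

```python
# Hypothetical usage after this change; net is any Cell with trainable params.
optim = Momentum(params=net.trainable_params(), learning_rate=0.1,
                 momentum=0.9, use_nesterov=True)
```

Routing `use_nesterov` into `P.ApplyMomentum` in `__init__` keeps `construct` untouched, since the op instance itself carries the flag.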