|
|
|
@ -2169,6 +2169,8 @@ class ApplyMomentum(PrimitiveWithInfer):
|
|
|
|
|
Refer to the paper `On the importance of initialization and momentum in deep
|
|
|
|
|
learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
|
|
|
|
|
|
|
|
|
|
Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.
|
|
|
|
|
|
|
|
|
|
Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
|
|
|
|
|
to make the data types consistent.
|
|
|
|
|
If they have different data types, lower priority data type will be converted to
|
|
|
|
@ -2194,11 +2196,14 @@ class ApplyMomentum(PrimitiveWithInfer):
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, parameters to be updated.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If `use_locking` or `use_nesterov` is not a bool, or if `gradient_scale` is not a float.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU``
|
|
|
|
|
|
|
|
|
|
Examples:
    Please refer to the usage in :class:`mindspore.nn.Momentum`.
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
|
|
|
@ -2210,6 +2215,9 @@ class ApplyMomentum(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
    """Initialize ApplyMomentum.

    Validates the constructor flags, registers the primitive's input/output
    names, and records whether the target device is Ascend (TBE backend).

    Args:
        use_nesterov (bool): Enable Nesterov momentum. Default: False.
        use_locking (bool): Enable update locking. Default: False.
        gradient_scale (float): Scale applied to the gradient. Default: 1.0.
    """
    # Flag arguments must be plain bools; check_bool returns the validated
    # value (presumably raising TypeError otherwise — validator is
    # project-defined).
    self.use_nesterov = validator.check_bool(use_nesterov)
    self.use_locking = validator.check_bool(use_locking)
    # gradient_scale is only type-checked here; the @prim_attr_register
    # decorator is what actually stores constructor args as attributes.
    validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)
    # Declare the primitive's I/O names for the framework's IR.
    input_names = ['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum']
    self.init_prim_io_names(inputs=input_names, outputs=['output'])
    # Remember whether we run on the Ascend (TBE) backend.
    device = context.get_context("device_target")
    self.is_tbe = device == "Ascend"
|
|
|
|
|