|
|
|
@ -22,6 +22,8 @@ from functools import reduce
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from ... import context
|
|
|
|
|
from ..._c_expression import signature_rw as sig_rw
|
|
|
|
|
from ..._c_expression import signature_kind as sig_kind
|
|
|
|
|
from ..._checkparam import ParamValidator as validator
|
|
|
|
|
from ..._checkparam import Rel, check_bool, check_int_positive
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
|
|
|
|
|
filter(lambda x: x.requires_grad, net.get_parameters()))
|
|
|
|
|
>>> model = Model(net, loss, opt)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
|
|
|
|
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
|
|
|
|
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
|
|
|
|
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
|
|
|
|
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
|
|
|
|
)
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
|
|
|
|
|
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
|
|
|
|
|
outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
|
|
|
|
|
validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT)
|
|
|
|
|
validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT)
|
|
|
|
|
validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE)
|
|
|
|
|
validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE)
|
|
|
|
|
validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE)
|
|
|
|
|
return v_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
|
|
|
|
|
validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
|
|
|
|
|
validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
|
|
|
|
|
v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
|
|
|
|
|
validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
|
|
|
|
|
validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
|
|
|
|
|
validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64])
|
|
|
|
|
return v_type
|
|
|
|
|
return g_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SmoothL1Loss(PrimitiveWithInfer):
|
|
|
|
|