@@ -1507,8 +1507,11 @@ class ApplyMomentum(PrimitiveWithInfer):
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
+        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
+        if self.is_tbe:
+            return v_shape, v_shape
         return v_shape
 
     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
@@ -1519,6 +1522,8 @@ class ApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
+        if self.is_tbe:
+            return g_dtype, g_dtype
         return g_dtype
 
 
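
Note (not part of the patch): a minimal usage sketch, assuming an Ascend build of MindSpore and the mindspore.ops.operations API; it only exercises the infer functions changed above, and the shapes are illustrative.

    import mindspore.context as context
    from mindspore.ops import operations as P

    context.set_context(device_target="Ascend")   # makes is_tbe True in __init__
    apply_momentum = P.ApplyMomentum()
    v_shape = [2, 3]
    # On Ascend the primitive now reports two outputs (updated variable and
    # updated accumulation); on CPU/GPU it still reports a single output.
    print(apply_momentum.infer_shape(v_shape, v_shape, [], v_shape, []))
    # -> ([2, 3], [2, 3]) on Ascend, [2, 3] on other targets
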
@@ -2810,13 +2815,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
         validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape
+        return var_shape, accum_shape
 
     def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
         args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
         validator.check_tensor_type_same(args, (mstype.float32,), self.name)
         validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
-        return var_type
+        return var_type, accum_type
 
 
 class ApplyProximalAdagrad(PrimitiveWithInfer):
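
Note (not part of the patch): a sketch of the call-site impact, assuming the usual mindspore.ops.operations API and PyNative-style direct invocation; the tensor names and values are placeholders.

    import numpy as np
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    sparse_apply_adagrad = P.SparseApplyAdagrad(lr=0.01)
    var = Parameter(Tensor(np.ones((3, 2), np.float32)), name="var")
    accum = Parameter(Tensor(np.ones((3, 2), np.float32)), name="accum")
    grad = Tensor(np.full((2, 2), 0.1, np.float32))
    indices = Tensor(np.array([0, 2], np.int32))

    # The primitive now advertises two outputs on every backend, so callers
    # unpack (updated var, updated accum) instead of a single tensor.
    var_out, accum_out = sparse_apply_adagrad(var, accum, grad, indices)
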
@@ -3074,11 +3079,14 @@ class ApplyFtrl(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
                                 outputs=['output'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                     lr_power_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
+        if self.is_tbe:
+            return var_shape, var_shape, var_shape
         return var_shape
 
     def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -3090,6 +3098,8 @@ class ApplyFtrl(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
+        if self.is_tbe:
+            return var_type, var_type, var_type
         return var_type
 
 
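
Note (not part of the patch): the same pattern as the ApplyMomentum sketch above, assuming the mindspore.ops.operations API; only the infer_shape changed here is exercised, and the shapes are illustrative.

    from mindspore.ops import operations as P

    apply_ftrl = P.ApplyFtrl()
    shape = [2, 3]
    # On Ascend the op now reports three outputs (var, accum, linear);
    # other targets keep the single-output behaviour.
    print(apply_ftrl.infer_shape(shape, shape, shape, shape, [], [], [], []))
    # -> ([2, 3], [2, 3], [2, 3]) on Ascend, [2, 3] elsewhere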