|
|
|
@ -2433,7 +2433,10 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|
|
|
|
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, has the same shape and type as `var`.
|
|
|
|
|
Tuple of 2 Tensor, the updated parameters.
|
|
|
|
|
|
|
|
|
|
- **var** (Tensor) - The same shape and data type as `var`.
|
|
|
|
|
- **accum** (Tensor) - The same shape and data type as `accum`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@ -2448,13 +2451,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|
|
|
|
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
|
|
|
|
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
|
|
|
|
return var_shape
|
|
|
|
|
return var_shape, accum_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
|
|
|
|
|
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
|
|
|
|
|
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
|
|
|
|
|
return var_type
|
|
|
|
|
return var_type, accum_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LARSUpdate(PrimitiveWithInfer):
|
|
|
|
|