@@ -69,6 +69,37 @@ class _ScatterOp(PrimitiveWithInfer):
         return x_dtype
 
 
+class _ScatterOp_Dynamic(PrimitiveWithCheck):
+    """
+    Defines Scatter operators with dynamic shape
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
+        sig.make_sig('updates', dtype=sig.sig_dtype.T)
+    )
+
+    def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
+        if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
+            raise ValueError(f"For '{prim_name}', "
+                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
+                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
+
+    @prim_attr_register
+    def __init__(self, use_locking=False):
+        """Initialize _ScatterOp_Dynamic"""
+        validator.check_value_type('use_locking', use_locking, [bool], self.name)
+        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
+
+    def check_shape(self, x_shape, indices_shape, updates_shape):
+        self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
+
+    def check_dtype(self, x_dtype, indices_dtype, updates_dtype):
+        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
+        args = {"x": x_dtype, "updates": updates_dtype}
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
+
+
 class _ScatterNdOp(_ScatterOp):
     """
     Defines _ScatterNd operators
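Note: `_ScatterOp_Dynamic` derives from `PrimitiveWithCheck`, so it only validates inputs via `check_shape`/`check_dtype` rather than inferring the output like the `PrimitiveWithInfer`-based `_ScatterOp`, which is what lets dynamic (`-1`) shapes pass through. The rule enforced by `_check_scatter_shape` is `updates.shape == indices.shape + x.shape[1:]`. A minimal illustration of that rule on plain Python lists (the sample shapes below are made up for illustration, not taken from the diff):

```python
# Illustration only: the shape rule checked by _check_scatter_shape, on plain lists.
x_shape = [2, 3]         # the Parameter being scattered into
indices_shape = [2]      # one index per selected row
updates_shape = [2, 3]   # must equal indices_shape + x_shape[1:]

assert updates_shape == indices_shape + x_shape[1:]

# With a dynamic indices shape the consistency check is skipped,
# mirroring the `indices_shape != [-1]` guard in the diff.
dynamic_indices_shape = [-1]
assert dynamic_indices_shape == [-1]  # no updates/x check in this case
```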
@@ -2723,7 +2754,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
         return x_dtype
 
 
-class ScatterUpdate(_ScatterOp):
+class ScatterUpdate(_ScatterOp_Dynamic):
     """
     Updates tensor value by using input indices and value.
@@ -2757,20 +2788,12 @@ class ScatterUpdate(_ScatterOp):
         [[2.0, 1.2, 1.0],
          [3.0, 1.2, 1.0]]
     """
 
     @prim_attr_register
     def __init__(self, use_locking=True):
         """Initialize ScatterUpdate"""
         validator.check_value_type('use_locking', use_locking, [bool], self.name)
         self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
 
-    def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
-        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
-        args = {"x": x_dtype, "value": value_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
-        return x_dtype
-
 
 class ScatterNdUpdate(_ScatterNdOp):
     """
     Updates tensor value by using input indices and value.
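For context, a usage sketch of `ScatterUpdate` after this change, consistent with the expected output kept in the docstring above (the imports and input values are assumptions for illustration, not part of the diff):

```python
# Sketch only: standard MindSpore imports assumed; values chosen to match the
# [[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]] result shown in the docstring.
import numpy as np
import mindspore
from mindspore import Parameter, Tensor, ops

input_x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], np.float32)), name="x")
indices = Tensor(np.array([0, 1]), mindspore.int32)
updates = Tensor(np.array([[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]]), mindspore.float32)

scatter_update = ops.ScatterUpdate()
output = scatter_update(input_x, indices, updates)
# Rows 0 and 1 are overwritten, so output (and input_x) equals `updates`.
```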
@@ -2891,7 +2914,7 @@ class ScatterMin(_ScatterOp):
     """
 
 
-class ScatterAdd(_ScatterOp):
+class ScatterAdd(_ScatterOp_Dynamic):
     """
     Updates the value of the input tensor through the add operation.
@@ -2923,6 +2946,11 @@ class ScatterAdd(_ScatterOp):
         >>> output = scatter_add(input_x, indices, updates)
         [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
     """
+    @prim_attr_register
+    def __init__(self, use_locking=False):
+        """Initialize ScatterAdd"""
+        validator.check_value_type('use_locking', use_locking, [bool], self.name)
+        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
 
 
 class ScatterSub(_ScatterOp):
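Similarly, a usage sketch for `ScatterAdd` consistent with the expected output above; repeated indices accumulate, which is what produces the `[3.0, 3.0, 3.0]` row (imports and values are assumptions for illustration):

```python
# Sketch only: standard MindSpore imports assumed.
import numpy as np
import mindspore
from mindspore import Parameter, Tensor, ops

input_x = Parameter(Tensor(np.zeros((2, 3), np.float32)), name="x")
indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)   # row 1 is selected three times
updates = Tensor(np.ones((2, 2, 3), np.float32))                # shape = indices.shape + x.shape[1:]

scatter_add = ops.ScatterAdd()
output = scatter_add(input_x, indices, updates)
# Row 0 accumulates one update, row 1 accumulates three:
# [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
```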