|
|
|
@ -38,6 +38,39 @@ from ..._c_expression import signature_dtype as sig_dtype
|
|
|
|
|
from ..._c_expression import typing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ScatterOp(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Define Scatter operators
|
|
|
|
|
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
|
|
|
|
('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
|
|
|
|
)
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
|
|
|
|
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
|
|
|
|
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
|
|
|
|
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
|
|
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_locking=False):
|
|
|
|
|
"""Init _ScatterOp"""
|
|
|
|
|
validator.check_value_type('use_locking', use_locking, [bool], self.name)
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
|
|
|
|
_ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
|
|
|
|
args = {"x": x_dtype, "updates": updates_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
|
|
|
|
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
|
|
|
|
validator.check_value_type('axis', axis, [int, tuple], prim_name)
|
|
|
|
@ -2221,7 +2254,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScatterUpdate(PrimitiveWithInfer):
|
|
|
|
|
class ScatterUpdate(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update tensor value by using input indices and value.
|
|
|
|
|
|
|
|
|
@ -2233,8 +2266,8 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
|
|
|
|
- **indices** (Tensor) - The index of input tensor. With int32 data type.
|
|
|
|
|
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
|
|
|
|
and update.shape = indices.shape + input_x.shape[1:].
|
|
|
|
|
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
|
|
|
|
and updates.shape = indices.shape + input_x.shape[1:].
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, has the same shape and type as `input_x`.
|
|
|
|
@ -2243,27 +2276,17 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|
|
|
|
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
|
|
|
|
|
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
|
|
|
|
|
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
|
|
|
|
>>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
|
|
|
|
|
>>> update = Tensor(np_update, mindspore.float32)
|
|
|
|
|
>>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
|
|
|
|
|
>>> updates = Tensor(np_updates, mindspore.float32)
|
|
|
|
|
>>> op = P.ScatterUpdate()
|
|
|
|
|
>>> output = op(input_x, indices, update)
|
|
|
|
|
>>> output = op(input_x, indices, updates)
|
|
|
|
|
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
|
|
|
|
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_locking=True):
|
|
|
|
|
"""Init ScatterUpdate"""
|
|
|
|
|
validator.check_value_type('use_locking', use_locking, [bool], self.name)
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape, value_shape):
|
|
|
|
|
if indices_shape + x_shape[1:] != value_shape:
|
|
|
|
|
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
|
|
|
|
|
return x_shape
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
|
|
|
@ -2323,14 +2346,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
|
|
|
|
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
|
|
|
|
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
|
|
|
|
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
|
|
|
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScatterMax(PrimitiveWithInfer):
|
|
|
|
|
class ScatterMax(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of the input tensor through the max operation.
|
|
|
|
|
|
|
|
|
@ -2364,18 +2380,8 @@ class ScatterMax(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
|
|
|
|
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
|
|
|
|
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
|
|
|
|
|
args = {"x": x_dtype, "updates": updates_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScatterMin(PrimitiveWithInfer):
|
|
|
|
|
class ScatterMin(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of the input tensor through the min operation.
|
|
|
|
|
|
|
|
|
@ -2403,24 +2409,8 @@ class ScatterMin(PrimitiveWithInfer):
|
|
|
|
|
[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_locking=False):
|
|
|
|
|
"""Init ScatterMin"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
|
|
|
|
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
|
|
|
|
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
|
|
|
|
|
args = {"x": x_dtype, "updates": updates_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScatterAdd(PrimitiveWithInfer):
|
|
|
|
|
class ScatterAdd(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of the input tensor through the add operation.
|
|
|
|
|
|
|
|
|
@ -2448,23 +2438,8 @@ class ScatterAdd(PrimitiveWithInfer):
|
|
|
|
|
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_locking=False):
|
|
|
|
|
"""Init ScatterAdd"""
|
|
|
|
|
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
|
|
|
|
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
|
|
|
|
|
args = {'x': x_dtype, 'updates': updates_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScatterSub(PrimitiveWithInfer):
|
|
|
|
|
class ScatterSub(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of the input tensor through the sub operation.
|
|
|
|
|
|
|
|
|
@ -2492,20 +2467,63 @@ class ScatterSub(PrimitiveWithInfer):
|
|
|
|
|
[[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_locking=False):
|
|
|
|
|
"""Init ScatterSub"""
|
|
|
|
|
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
|
|
|
|
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
class ScatterMul(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of the input tensor through the mul operation.
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
|
|
|
|
|
args = {'x': x_dtype, 'updates': updates_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
Using given values to update tensor value through the mul operation, along with the input indices.
|
|
|
|
|
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
use_locking (bool): Whether protect the assignment by a lock. Default: False.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Parameter) - The target parameter.
|
|
|
|
|
- **indices** (Tensor) - The index to do mul operation whose data type should be mindspore.int32.
|
|
|
|
|
- **updates** (Tensor) - The tensor doing the mul operation with `input_x`,
|
|
|
|
|
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Parameter, the updated `input_x`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
|
|
|
|
|
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
|
|
|
|
|
>>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
|
|
|
|
|
>>> scatter_mul = P.ScatterMul()
|
|
|
|
|
>>> output = scatter_mul(input_x, indices, updates)
|
|
|
|
|
[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScatterDiv(_ScatterOp):
|
|
|
|
|
"""
|
|
|
|
|
Update the value of the input tensor through the div operation.
|
|
|
|
|
|
|
|
|
|
Using given values to update tensor value through the div operation, along with the input indices.
|
|
|
|
|
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
use_locking (bool): Whether protect the assignment by a lock. Default: False.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Parameter) - The target parameter.
|
|
|
|
|
- **indices** (Tensor) - The index to do div operation whose data type should be mindspore.int32.
|
|
|
|
|
- **updates** (Tensor) - The tensor doing the div operation with `input_x`,
|
|
|
|
|
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Parameter, the updated `input_x`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
|
|
|
|
|
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
|
|
|
|
|
>>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
|
|
|
|
|
>>> scatter_div = P.ScatterDiv()
|
|
|
|
|
>>> output = scatter_div(input_x, indices, updates)
|
|
|
|
|
[[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpaceToDepth(PrimitiveWithInfer):
|
|
|
|
|