@@ -1953,7 +1953,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
     Using given values to update tensor value, along with the input indices.
 
     Args:
-        use_locking (bool): Whether protect the assignment by a lock. Defaule: True.
+        use_locking (bool): Whether protect the assignment by a lock. Default: True.
 
     Inputs:
         - **input_x** (Tensor) - The target tensor.
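
Reviewer note (not part of the diff): for context, a minimal usage sketch of ScatterNdUpdate with the signature documented above; the tensor values, the Parameter name "x", and the import style (which follows the convention used in this file's Examples blocks) are illustrative assumptions, not taken from this change.

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, Parameter
    >>> from mindspore.ops import operations as P
    >>> # The target tensor is held as a Parameter here so it can be updated in place
    >>> # (an assumption; the docstring only calls it "the target tensor").
    >>> input_x = Parameter(Tensor(np.zeros((2, 3)), mindspore.float32), name="x")
    >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
    >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
    >>> scatter_nd_update = P.ScatterNdUpdate(use_locking=True)
    >>> output = scatter_nd_update(input_x, indices, update)
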
@@ -1995,6 +1995,53 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         return x_dtype
 
 
+class ScatterMax(PrimitiveWithInfer):
+    """
+    Update the value of the input tensor through the max operation.
+
+    Using given values to update tensor value through the max operation, along with the input indices.
+
+    Args:
+        use_locking (bool): Whether protect the assignment by a lock. Default: True.
+
+    Inputs:
+        - **input_x** (Tensor) - The target tensor.
+        - **indices** (Tensor) - The index to do max operation whose data type should be int.
+        - **updates** (Tensor) - The tensor doing the maximum operation with 'input_x',
+          the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'.
+
+    Outputs:
+        Tensor, has the same shape and data type as `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
+        >>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
+        >>> scatter_max = P.ScatterMax()
+        >>> output = scatter_max(input_x, indices, update)
+        [[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
+    """
+
+    @prim_attr_register
+    def __init__(self, use_locking=True):
+        """Init ScatterMax"""
+        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
+        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
+
+    def infer_shape(self, x_shape, indices_shape, updates_shape):
+        if updates_shape and updates_shape != indices_shape + x_shape[1:]:
+            raise ValueError(f"For '{self.name}', the shape of update should be [] or "
+                             f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
+                             f"indices_shape: {indices_shape}, update_shape: {updates_shape}.")
+        return x_shape
+
+    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
+        validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
+        args = {"x": x_dtype, "updates": updates_dtype}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x_dtype
+
+
 class SpaceToDepth(PrimitiveWithInfer):
     r"""
     Rearrange blocks of spatial data into depth.
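
Reviewer note (not part of the diff): the check added in infer_shape reads as "updates must be empty or have shape indices_shape + x_shape[1:]". With the shapes from the new Examples block, that list arithmetic works out as:

    >>> x_shape, indices_shape = [2, 3], [2, 2]
    >>> indices_shape + x_shape[1:]
    [2, 2, 3]

which matches the np.ones([2, 2, 3]) update tensor used there, so the docstring example satisfies the shape check.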