|
|
@ -3372,11 +3372,25 @@ class GatherNd(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
class TensorScatterUpdate(PrimitiveWithInfer):
|
|
|
|
class TensorScatterUpdate(PrimitiveWithInfer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Updates tensor values using given values, along with the input indices.
|
|
|
|
Creates a new tensor by updating the positions in `input_x` indicicated by
|
|
|
|
|
|
|
|
`indices`, with values from `update`. This operation is almost equivalent to using
|
|
|
|
|
|
|
|
ScatterNd, except that the updates are applied on `input_x` instead of a zero tensor.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
`indices` must have rank atleast 2, the last axis is the depth of each index
|
|
|
|
|
|
|
|
vectors. For each index vector, there must be a corresponding value in `update`. If
|
|
|
|
|
|
|
|
the depth of each index tensor matches the rank of `input_x`, then each index
|
|
|
|
|
|
|
|
vector corresponds to a scalar in `input_x` and each update updates a scalar. If
|
|
|
|
|
|
|
|
the depth of each index tensor is less than the rnak of `input_x`, then each index
|
|
|
|
|
|
|
|
vector corresponds to a slice in `input_x`, and each update updates a slice.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
The order in which updates are applied is nondeterministic, meaning that if there
|
|
|
|
|
|
|
|
are multiple index vectors in `indices` that correspond to the same position, the
|
|
|
|
|
|
|
|
value of that position in the output will be nondeterministic.
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
|
|
|
|
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
|
|
|
|
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
|
|
|
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
|
|
|
|
|
|
|
The rank must be atleast 2.
|
|
|
|
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
|
|
|
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
|
|
|
and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
|
|
|
|
and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
|
|
|
|
|
|
|
|
|
|
|
@ -3388,7 +3402,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
|
|
|
|
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
|
|
|
|
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
|
|
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
Supported Platforms:
|
|
|
|
``Ascend``
|
|
|
|
``Ascend`` ``GPU``
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
|
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
|
|
|
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
|
|
|