|
|
|
@ -806,9 +806,11 @@ class Unique(Primitive):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Gather(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
r"""
|
|
|
|
|
Returns a slice of the input tensor based on the specified indices and axis.
|
|
|
|
|
|
|
|
|
|
Slices the input tensor base on the indices at specified axis. See the following example for more clear.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
|
|
|
The original Tensor.
|
|
|
|
@ -818,7 +820,8 @@ class Gather(PrimitiveWithCheck):
|
|
|
|
|
- **axis** (int) - Specifies the dimension index to gather indices.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
|
|
|
|
Tensor, the shape of tensor is
|
|
|
|
|
:math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If `axis` is not an int.
|
|
|
|
@ -835,6 +838,11 @@ class Gather(PrimitiveWithCheck):
|
|
|
|
|
[[ 2. 7.]
|
|
|
|
|
[ 4. 54.]
|
|
|
|
|
[ 2. 55.]]
|
|
|
|
|
>>> axis = 0
|
|
|
|
|
>>> output = ops.Gather()(input_params, input_indices, axis)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[3. 4. 54. 22.]
|
|
|
|
|
[2. 2. 55. 3.]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@ -2890,9 +2898,36 @@ class StridedSlice(PrimitiveWithInfer):
|
|
|
|
|
Starting from the beginning position, the fragment continues adding stride to the index until
|
|
|
|
|
all dimensions are not less than the ending position.
|
|
|
|
|
|
|
|
|
|
Given a `input_x[m1, m2, ..., mn]`, `begin`, `end` and `strides` will be vectors of length n.
|
|
|
|
|
|
|
|
|
|
In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask`, `shrink_axis_mask`)
|
|
|
|
|
the ith bit will correspond to the ith m.
|
|
|
|
|
|
|
|
|
|
If the ith bit of `begin_mask` is set, `begin[i]` is ignored and the fullest possible range in that dimension
|
|
|
|
|
is used instead. `end_mask` is analogous, except with the end range.
|
|
|
|
|
|
|
|
|
|
As for a 5*6*7 tensor, `x[2:,:3,:]` is equivalent to `x[2:5,0:3,0:7]`.
|
|
|
|
|
|
|
|
|
|
If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions as needed will be inserted between
|
|
|
|
|
other dimensions. Only one non-zero bit is allowed in `ellipsis_mask`.
|
|
|
|
|
|
|
|
|
|
As for a 5*6*7*8 tensor, `x[2:,...,:6]` is equivalent to `x[2:5,:,:,0:6]`.
|
|
|
|
|
`x[2:,...]` is equivalent to `x[2:5,:,:,:]`.
|
|
|
|
|
|
|
|
|
|
If the ith bit of `new_axis_mask` is set, `begin`, `end` and `strides` are ignored and a new length 1
|
|
|
|
|
dimension is added at the specified position in tthe output tensor.
|
|
|
|
|
|
|
|
|
|
As for a 5*6*7 tensor, `x[:2, newaxis, :6]` will produce a tensor with shape (2, 1, 7).
|
|
|
|
|
|
|
|
|
|
If the ith bit of `shrink_axis_mask` is set, ith size shrinks the dimension by 1, taking on the value
|
|
|
|
|
at index `begin[i]`, `end[i]` and `strides[i]` are ignored.
|
|
|
|
|
|
|
|
|
|
As for a 5*6*7 tensor, `x[:, 5, :]` will result in `shrink_axis_mask` equal to 4.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
The stride may be negative value, which causes reverse slicing.
|
|
|
|
|
The shape of `begin`, `end` and `strides` must be the same.
|
|
|
|
|
`begin` and `end` are zero-indexed. The element of `strides` must be non-zero.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
begin_mask (int): Starting index of the slice. Default: 0.
|
|
|
|
@ -3361,6 +3396,15 @@ class GatherNd(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
Using given indices to gather slices from a tensor with a specified shape.
|
|
|
|
|
|
|
|
|
|
`indices` is an K-dimensional integer tensor. Supposes it as a (K-1)-dimensional tensor and each element of it
|
|
|
|
|
defines a slice of `input_x`:
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
output[(i_0, ..., i_{K-2})] = input_x[indices[(i_0, ..., i_{K-2})]]
|
|
|
|
|
|
|
|
|
|
The last dimension of `indices` can not more than the rank of `input_x`:
|
|
|
|
|
:math:`indices.shape[-1] <= input_x.rank`.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - The target tensor to gather values.
|
|
|
|
|
- **indices** (Tensor) - The index tensor, with int data type.
|
|
|
|
|