|
|
|
@ -23,137 +23,6 @@ from ..primitive import PrimitiveWithInfer, prim_attr_register
|
|
|
|
|
from ..operations.math_ops import _infer_shape_reduce
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StridedSliceAICPU(PrimitiveWithInfer):
|
|
|
|
|
r"""
|
|
|
|
|
|
|
|
|
|
Extracts a strided slice of a tensor.
|
|
|
|
|
|
|
|
|
|
Given an input tensor, this operation inserts a dimension of length 1 at the dimension.
|
|
|
|
|
This operation extracts a fragment of size (end-begin)/stride from the given
|
|
|
|
|
'input_tensor'. Starting from the position specified by the begin, the fragment
|
|
|
|
|
continues adding stride to the index until all dimensions are not less than end.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
The stride may be negative value, which causes reverse slicing.
|
|
|
|
|
The shape of `begin`, `end` and `strides` must be the same.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
begin_mask (int): Starting index of the slice. Default: 0.
|
|
|
|
|
end_mask (int): Ending index of the slice. Default: 0.
|
|
|
|
|
ellipsis_mask (int): An int mask. Default: 0.
|
|
|
|
|
new_axis_mask (int): An int mask. Default: 0.
|
|
|
|
|
shrink_axis_mask (int): An int mask. Default: 0.
|
|
|
|
|
Currently all the masks are not in used. Use default 0 only.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - The input Tensor.
|
|
|
|
|
- **begin** (tuple[int]) - A tuple which represents the location where to start. Only
|
|
|
|
|
constant value is allowed.
|
|
|
|
|
- **end** (tuple[int]) - A tuple or which represents the maximum location where to stop.
|
|
|
|
|
Only constant value is allowed.
|
|
|
|
|
- **strides** (tuple[int]) - A tuple which represents the stride continuously added
|
|
|
|
|
before reach the maximum location. Only constant value is allowed.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor.
|
|
|
|
|
Explain with the following example.
|
|
|
|
|
- In the 0th dim, begin is 1, end is 2, and strides is 1,
|
|
|
|
|
because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
|
|
|
|
|
Thus, return the element with :math:`index = 1` in 0th dim, i.e., [[3, 3, 3], [4, 4, 4]].
|
|
|
|
|
- In the 1st dim, similarly, the interval is :math:`[0,1)`.
|
|
|
|
|
Based on the return value of the 0th dim, return the element with :math:`index = 0`,
|
|
|
|
|
i.e., [3, 3, 3].
|
|
|
|
|
- In the 2nd dim, similarly, the interval is :math:`[0,3)`.
|
|
|
|
|
Based on the return value of the 1st dim, return the element with :math:`index = 0,1,2`,
|
|
|
|
|
i.e., [3, 3, 3].
|
|
|
|
|
- Finally, the output is [3, 3, 3].
|
|
|
|
|
|
|
|
|
|
Examples
|
|
|
|
|
>>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
|
|
|
|
|
>>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
|
|
|
|
|
>>> slice = P.StridedSliceAICPU()
|
|
|
|
|
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 2))
|
|
|
|
|
>>> output.shape
|
|
|
|
|
(1, 1, 2)
|
|
|
|
|
>>> output
|
|
|
|
|
[[[3, 3]]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self,
|
|
|
|
|
begin_mask=0,
|
|
|
|
|
end_mask=0,
|
|
|
|
|
ellipsis_mask=0,
|
|
|
|
|
new_axis_mask=0,
|
|
|
|
|
shrink_axis_mask=0):
|
|
|
|
|
"""Initialize StrideSlice"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
|
|
|
|
|
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
|
|
|
|
|
validator.check_value_type('end_mask', end_mask, [int], self.name)
|
|
|
|
|
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
|
|
|
|
|
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
|
|
|
|
|
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, begin, end, strides):
|
|
|
|
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
|
|
|
|
|
validator.check_value_type("begin", begin_v, [tuple], self.name)
|
|
|
|
|
validator.check_value_type("end", end_v, [tuple], self.name)
|
|
|
|
|
validator.check_value_type("strides", strides_v, [tuple], self.name)
|
|
|
|
|
|
|
|
|
|
x_shape = x['shape']
|
|
|
|
|
x_shp_len = len(x_shape)
|
|
|
|
|
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
|
|
|
|
|
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and "
|
|
|
|
|
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.")
|
|
|
|
|
|
|
|
|
|
ret_shape = []
|
|
|
|
|
append_dimensions = []
|
|
|
|
|
shrink_pos = bin(self.shrink_axis_mask)[::-1]
|
|
|
|
|
new_pos = bin(self.new_axis_mask)[::-1]
|
|
|
|
|
for i in range(x_shp_len):
|
|
|
|
|
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'
|
|
|
|
|
if i < (len(new_pos) - 2) and new_pos[i] == '1':
|
|
|
|
|
ret_shape.append(1)
|
|
|
|
|
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
|
|
|
|
|
continue
|
|
|
|
|
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
|
|
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
|
|
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
begin_idx = begin_v[i]
|
|
|
|
|
end_idx = end_v[i]
|
|
|
|
|
strides_idx = strides_v[i]
|
|
|
|
|
if self.begin_mask:
|
|
|
|
|
begin_idx = 0
|
|
|
|
|
if self.end_mask:
|
|
|
|
|
end_idx = x_shape[i]
|
|
|
|
|
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
|
|
|
|
|
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
|
|
|
|
|
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
|
|
|
|
|
if strides_idx > 0:
|
|
|
|
|
# If sliced forward , end_idx >= begin_idx
|
|
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
|
|
|
|
|
if begin_idx < 0 < end_idx:
|
|
|
|
|
# Turn negative begin_idx into positive values
|
|
|
|
|
begin_idx = x_shape[i] + begin_idx
|
|
|
|
|
num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx
|
|
|
|
|
else:
|
|
|
|
|
# If sliced backwards, end_idx <= begin_idx
|
|
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE)
|
|
|
|
|
if end_idx < 0 < begin_idx:
|
|
|
|
|
# Turn negative end_idx into positive values
|
|
|
|
|
end_idx = x_shape[i] + end_idx
|
|
|
|
|
num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx
|
|
|
|
|
|
|
|
|
|
ret_shape.append(num_elems)
|
|
|
|
|
if append_dimensions:
|
|
|
|
|
ret_shape += append_dimensions[::-1]
|
|
|
|
|
return {'shape': ret_shape,
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExtractImagePatches(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Extracts patches from images.
|
|
|
|
|