|
|
|
@ -104,23 +104,24 @@ def flip(input, dims, name=None):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def roll(input, shifts, dims=None):
|
|
|
|
|
def roll(x, shifts, axis=None, name=None):
|
|
|
|
|
"""
|
|
|
|
|
:alias_main: paddle.roll
|
|
|
|
|
:alias: paddle.roll,paddle.tensor.roll,paddle.tensor.manipulation.roll
|
|
|
|
|
|
|
|
|
|
Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond
|
|
|
|
|
the last position are re-introduced at the first position. If a dimension is not specified,
|
|
|
|
|
Roll the `x` tensor along the given axis(axes). With specific 'shifts', Elements that
|
|
|
|
|
roll beyond the last position are re-introduced at the first according to 'shifts'.
|
|
|
|
|
If a axis is not specified,
|
|
|
|
|
the tensor will be flattened before rolling and then restored to the original shape.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input (Variable): The input tensor variable.
|
|
|
|
|
x (Variable): The x tensor variable as input.
|
|
|
|
|
shifts (int|list|tuple): The number of places by which the elements
|
|
|
|
|
of the `input` tensor are shifted.
|
|
|
|
|
dims (int|list|tuple|None): Dimentions along which to roll.
|
|
|
|
|
of the `x` tensor are shifted.
|
|
|
|
|
axis (int|list|tuple|None): axis(axes) along which to roll.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Variable: A Tensor with same data type as `input`.
|
|
|
|
|
Variable: A Tensor with same data type as `x`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
@ -131,48 +132,56 @@ def roll(input, shifts, dims=None):
|
|
|
|
|
data = np.array([[1.0, 2.0, 3.0],
|
|
|
|
|
[4.0, 5.0, 6.0],
|
|
|
|
|
[7.0, 8.0, 9.0]])
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
x = fluid.dygraph.to_variable(data)
|
|
|
|
|
out_z1 = paddle.roll(x, shifts=1)
|
|
|
|
|
print(out_z1.numpy())
|
|
|
|
|
#[[9. 1. 2.]
|
|
|
|
|
# [3. 4. 5.]
|
|
|
|
|
# [6. 7. 8.]]
|
|
|
|
|
out_z2 = paddle.roll(x, shifts=1, dims=0)
|
|
|
|
|
print(out_z2.numpy())
|
|
|
|
|
#[[7. 8. 9.]
|
|
|
|
|
# [1. 2. 3.]
|
|
|
|
|
# [4. 5. 6.]]
|
|
|
|
|
paddle.enable_imperative()
|
|
|
|
|
x = paddle.imperative.to_variable(data)
|
|
|
|
|
out_z1 = paddle.roll(x, shifts=1)
|
|
|
|
|
print(out_z1.numpy())
|
|
|
|
|
#[[9. 1. 2.]
|
|
|
|
|
# [3. 4. 5.]
|
|
|
|
|
# [6. 7. 8.]]
|
|
|
|
|
out_z2 = paddle.roll(x, shifts=1, axis=0)
|
|
|
|
|
print(out_z2.numpy())
|
|
|
|
|
#[[7. 8. 9.]
|
|
|
|
|
# [1. 2. 3.]
|
|
|
|
|
# [4. 5. 6.]]
|
|
|
|
|
"""
|
|
|
|
|
helper = LayerHelper("roll", **locals())
|
|
|
|
|
origin_shape = input.shape
|
|
|
|
|
origin_shape = x.shape
|
|
|
|
|
if type(shifts) == int:
|
|
|
|
|
shifts = [shifts]
|
|
|
|
|
if type(dims) == int:
|
|
|
|
|
dims = [dims]
|
|
|
|
|
|
|
|
|
|
if dims:
|
|
|
|
|
check_type(dims, 'dims', (list, tuple), 'roll')
|
|
|
|
|
if type(axis) == int:
|
|
|
|
|
axis = [axis]
|
|
|
|
|
|
|
|
|
|
len_origin_shape = len(origin_shape)
|
|
|
|
|
if axis:
|
|
|
|
|
for i in range(len(axis)):
|
|
|
|
|
if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"axis is out of range, it should be in range [{}, {}), but received {}".
|
|
|
|
|
format(-len_origin_shape, len_origin_shape, axis))
|
|
|
|
|
|
|
|
|
|
if axis:
|
|
|
|
|
check_type(axis, 'axis', (list, tuple), 'roll')
|
|
|
|
|
check_type(shifts, 'shifts', (list, tuple), 'roll')
|
|
|
|
|
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
if dims is None:
|
|
|
|
|
input = core.ops.reshape(input, 'shape', [-1, 1])
|
|
|
|
|
dims = [0]
|
|
|
|
|
out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
|
|
|
|
|
if axis is None:
|
|
|
|
|
x = core.ops.reshape(x, 'shape', [-1, 1])
|
|
|
|
|
axis = [0]
|
|
|
|
|
out = core.ops.roll(x, 'axis', axis, 'shifts', shifts)
|
|
|
|
|
return core.ops.reshape(out, 'shape', origin_shape)
|
|
|
|
|
|
|
|
|
|
out = helper.create_variable_for_type_inference(input.dtype)
|
|
|
|
|
out = helper.create_variable_for_type_inference(x.dtype)
|
|
|
|
|
|
|
|
|
|
if dims is None:
|
|
|
|
|
input = reshape(input, shape=[-1, 1])
|
|
|
|
|
dims = [0]
|
|
|
|
|
if axis is None:
|
|
|
|
|
x = reshape(x, shape=[-1, 1])
|
|
|
|
|
axis = [0]
|
|
|
|
|
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type='roll',
|
|
|
|
|
inputs={'X': input},
|
|
|
|
|
inputs={'X': x},
|
|
|
|
|
outputs={'Out': out},
|
|
|
|
|
attrs={'dims': dims,
|
|
|
|
|
attrs={'axis': axis,
|
|
|
|
|
'shifts': shifts})
|
|
|
|
|
out = reshape(out, shape=origin_shape, inplace=True)
|
|
|
|
|
return out
|
|
|
|
|