|
|
|
@ -791,29 +791,29 @@ def gather(x, index, axis=None, name=None):
|
|
|
|
|
|
|
|
|
|
def unbind(input, axis=0):
|
|
|
|
|
"""
|
|
|
|
|
:alias_main: paddle.tensor.unbind
|
|
|
|
|
:alias: paddle.tensor.unbind,paddle.tensor.manipulation.unbind
|
|
|
|
|
|
|
|
|
|
Removes a tensor dimension, then split the input tensor into multiple sub-Tensors.
|
|
|
|
|
Args:
|
|
|
|
|
input (Variable): The input variable which is an N-D Tensor, data type being float32, float64, int32 or int64.
|
|
|
|
|
|
|
|
|
|
axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind. If :math:`axis < 0`, the
|
|
|
|
|
dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
|
|
|
|
|
Args:
|
|
|
|
|
input (Tensor): The input variable which is an N-D Tensor, data type being float32, float64, int32 or int64.
|
|
|
|
|
axis (int32|int64, optional): A scalar with type ``int32|int64`` shape [1]. The dimension along which to unbind.
|
|
|
|
|
If :math:`axis < 0`, the dimension to unbind along is :math:`rank(input) + axis`. Default is 0.
|
|
|
|
|
Returns:
|
|
|
|
|
list(Variable): The list of segmented Tensor variables.
|
|
|
|
|
list(Tensor): The list of segmented Tensor variables.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
import numpy as np
|
|
|
|
|
# input is a variable which shape is [3, 4, 5]
|
|
|
|
|
input = paddle.fluid.data(
|
|
|
|
|
name="input", shape=[3, 4, 5], dtype="float32")
|
|
|
|
|
[x0, x1, x2] = paddle.tensor.unbind(input, axis=0)
|
|
|
|
|
np_input = np.random.rand(3, 4, 5).astype('float32')
|
|
|
|
|
input = paddle.to_tensor(np_input)
|
|
|
|
|
[x0, x1, x2] = paddle.unbind(input, axis=0)
|
|
|
|
|
# x0.shape [4, 5]
|
|
|
|
|
# x1.shape [4, 5]
|
|
|
|
|
# x2.shape [4, 5]
|
|
|
|
|
[x0, x1, x2, x3] = paddle.tensor.unbind(input, axis=1)
|
|
|
|
|
[x0, x1, x2, x3] = paddle.unbind(input, axis=1)
|
|
|
|
|
# x0.shape [3, 5]
|
|
|
|
|
# x1.shape [3, 5]
|
|
|
|
|
# x2.shape [3, 5]
|
|
|
|
@ -837,6 +837,8 @@ def unbind(input, axis=0):
|
|
|
|
|
helper.create_variable_for_type_inference(dtype=helper.input_dtype())
|
|
|
|
|
for i in range(num)
|
|
|
|
|
]
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
return core.ops.unbind(input, num, 'axis', axis)
|
|
|
|
|
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type="unbind",
|
|
|
|
|