|
|
|
@ -14,7 +14,7 @@ __all__ = [
|
|
|
|
|
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
|
|
|
|
|
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
|
|
|
|
|
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
|
|
|
|
|
'sequence_first_step', 'sequence_last_step', 'dropout'
|
|
|
|
|
'sequence_first_step', 'sequence_last_step', 'dropout', 'split'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1504,3 +1504,63 @@ def reduce_min(input, dim=None, keep_dim=False):
|
|
|
|
|
'reduce_all': True if dim == None else False
|
|
|
|
|
})
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split(input, num_or_sections, dim=-1):
|
|
|
|
|
"""
|
|
|
|
|
Splits the tensor into multiple sub-tensors.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input (Variable): The input variable which is a Tensor or LoDTensor.
|
|
|
|
|
num_or_sections (int|list): If :attr:`num_or_sections` is an integer,
|
|
|
|
|
then the integer indicates the number of equal sized sub-tensors
|
|
|
|
|
that the tensor will be divided into. If :attr:`num_or_sections`
|
|
|
|
|
is a list of integers, the length of list indicates the number of
|
|
|
|
|
sub-tensors and the integers indicate the sizes of sub-tensors'
|
|
|
|
|
:attr:`dim` dimension orderly.
|
|
|
|
|
dim (int): The dimension along which to split. If :math:`dim < 0`, the
|
|
|
|
|
dimension to split along is :math:`rank(input) + dim`.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List: The list of segmented tensor variables.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
# x is a Tensor variable with shape [3, 9, 5]:
|
|
|
|
|
x0, x1, x2 = fluid.layers.split(x, num_or_sections=3, dim=1)
|
|
|
|
|
x0.shape # [3, 3, 5]
|
|
|
|
|
x1.shape # [3, 3, 5]
|
|
|
|
|
x2.shape # [3, 3, 5]
|
|
|
|
|
x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
|
|
|
|
|
x0.shape # [3, 2, 5]
|
|
|
|
|
x1.shape # [3, 3, 5]
|
|
|
|
|
x2.shape # [3, 4, 5]
|
|
|
|
|
"""
|
|
|
|
|
helper = LayerHelper('split', **locals())
|
|
|
|
|
input_shape = input.shape
|
|
|
|
|
dim = (len(input_shape) + dim) if dim < 0 else dim
|
|
|
|
|
if isinstance(num_or_sections, int):
|
|
|
|
|
assert num_or_sections > 1, 'num_or_sections must be more than 1.'
|
|
|
|
|
assert input_shape[
|
|
|
|
|
dim] % num_or_sections == 0, 'num_or_sections must evenly divide input.shape[dim].'
|
|
|
|
|
num = num_or_sections
|
|
|
|
|
else:
|
|
|
|
|
assert len(num_or_sections) < input_shape[
|
|
|
|
|
dim], 'len(num_or_sections) must not be more than input.shape[dim].'
|
|
|
|
|
num = len(num_or_sections)
|
|
|
|
|
outs = [
|
|
|
|
|
helper.create_tmp_variable(dtype=helper.input_dtype())
|
|
|
|
|
for i in range(num)
|
|
|
|
|
]
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type='split',
|
|
|
|
|
inputs={'X': input},
|
|
|
|
|
outputs={'Out': outs},
|
|
|
|
|
attrs={
|
|
|
|
|
'num': num_or_sections if isinstance(num_or_sections, int) else 0,
|
|
|
|
|
'sections': num_or_sections
|
|
|
|
|
if isinstance(num_or_sections, list) else [],
|
|
|
|
|
'axis': dim
|
|
|
|
|
})
|
|
|
|
|
return outs
|
|
|
|
|