|
|
|
@ -33,6 +33,7 @@ __all__ = [
|
|
|
|
|
'fill_constant',
|
|
|
|
|
'argmin',
|
|
|
|
|
'argmax',
|
|
|
|
|
'argsort',
|
|
|
|
|
'ones',
|
|
|
|
|
'zeros',
|
|
|
|
|
'reverse',
|
|
|
|
@ -438,6 +439,56 @@ def argmax(x, axis=0):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def argsort(input, axis=-1):
|
|
|
|
|
"""
|
|
|
|
|
Performs sorting on the input Variable along the given axis, and outputs
|
|
|
|
|
sorted data Varibale and its corresponding index Variable with the same
|
|
|
|
|
shape as :attr:`input`.
|
|
|
|
|
|
|
|
|
|
.. code-block:: text
|
|
|
|
|
|
|
|
|
|
For example, the given axis is -1 and the input Variable
|
|
|
|
|
|
|
|
|
|
input = [[0.15849551, 0.45865775, 0.8563702 ],
|
|
|
|
|
[0.12070083, 0.28766365, 0.18776911]],
|
|
|
|
|
|
|
|
|
|
after argsort, the sorted Vairable becomes
|
|
|
|
|
|
|
|
|
|
out = [[0.15849551, 0.45865775, 0.8563702 ],
|
|
|
|
|
[0.12070083, 0.18776911, 0.28766365]],
|
|
|
|
|
|
|
|
|
|
and the sorted indices along the given axis turn outs to be
|
|
|
|
|
|
|
|
|
|
indices = [[0, 1, 2],
|
|
|
|
|
[0, 2, 1]]
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input(Variable): The input Variable for sorting.
|
|
|
|
|
axis(int): The axis along which to sort the input Variable. When
|
|
|
|
|
:attr:`axis` < 0, the actual axis will be :attr:`axis` +
|
|
|
|
|
rank(:attr:`input`). Default -1, the last dimension.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
tuple: A tuple of sorted data Variable and the sorted indices.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
input = fluid.layers.data(data=[2, 3])
|
|
|
|
|
out, indices = fluid.layers.argsort(input, axis=0)
|
|
|
|
|
"""
|
|
|
|
|
helper = LayerHelper("argsort", **locals())
|
|
|
|
|
out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True)
|
|
|
|
|
ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True)
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type='argsort',
|
|
|
|
|
inputs={'X': input},
|
|
|
|
|
outputs={'Out': out,
|
|
|
|
|
'Indics': ids},
|
|
|
|
|
attts={'axis': axis})
|
|
|
|
|
return out, ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ones(shape, dtype, force_cpu=False):
|
|
|
|
|
"""
|
|
|
|
|
**ones**
|
|
|
|
|