improve unique op (#26537)

* add unique_v2 op

* remove unique_v2 op

* update doc
revert-26856-strategy_example2
Zhang Ting 5 years ago committed by GitHub
parent a004dfde3d
commit 0a895bc0df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,17 +24,63 @@ class UniqueOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "unique");
OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique");
auto in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
in_dims.size(), 1,
platform::errors::InvalidArgument("The Input(X) should be 1-D Tensor, "
"But now the dims of Input(X) is %d.",
in_dims.size()));
if (!ctx->Attrs().Get<bool>("is_sorted")) {
OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique");
PADDLE_ENFORCE_EQ(in_dims.size(), 1,
platform::errors::InvalidArgument(
"The Input(X) should be 1-D Tensor, "
"But now the dims of Input(X) is %d.",
in_dims.size()));
ctx->SetOutputDim("Out", {-1});
ctx->SetOutputDim("Index", in_dims);
return;
}
bool return_index = ctx->Attrs().Get<bool>("return_index");
bool return_inverse = ctx->Attrs().Get<bool>("return_inverse");
bool return_counts = ctx->Attrs().Get<bool>("return_counts");
auto axis_vec = ctx->Attrs().Get<std::vector<int>>("axis");
if (return_index) {
OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "unique");
}
if (return_inverse) {
OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index", "unique");
}
if (return_counts) {
OP_INOUT_CHECK(ctx->HasOutput("Counts"), "Output", "Counts", "unique");
}
ctx->SetOutputDim("Out", {-1});
ctx->SetOutputDim("Index", in_dims);
if (axis_vec.empty()) {
ctx->SetOutputDim("Out", {-1});
if (return_inverse) {
ctx->SetOutputDim("Index", {framework::product(in_dims)});
}
} else {
int axis = axis_vec[0];
if (axis < 0) {
axis += in_dims.size();
}
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument("The axis(%d) should be less than "
"the dimension size(%d) of x.",
axis, in_dims.size()));
auto out_dims = in_dims;
out_dims[axis] = -1;
ctx->SetOutputDim("Out", out_dims);
if (return_inverse) {
ctx->SetOutputDim("Index", {in_dims[axis]});
}
}
if (return_index) {
ctx->SetOutputDim("Indices", {-1});
}
if (return_counts) {
ctx->SetOutputDim("Counts", {-1});
}
}
protected:
@ -49,14 +95,47 @@ class UniqueOp : public framework::OperatorWithKernel {
class UniqueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input tensor. It should be a 1-D tensor.");
AddInput("X",
"Input tensor. It should be a 1-D tensor when Attr(is_sorted)"
" is fasle or a N-D tensor when Attr(is_sorted) is true.");
AddAttr<int>("dtype", "data type for output index");
AddOutput("Out", "A unique subsequence for input tensor.");
AddOutput("Index",
"An index tensor pointing to unique subsequence, which has "
"identical shape with input tensor and int64 dtype.");
"Equivalent to inverse in numpy.unique, "
"the indices for where elements in the original input ended up "
"in the returned unique tensor.");
AddOutput(
"Indices",
"The indices of the input tensor that result in the unique tensor.")
.AsDispensable();
AddOutput("Counts", "The counts for each unique element.").AsDispensable();
AddAttr<bool>("return_index",
"If True, also return the indices of the input"
" tensor that result in the unique Tensor.")
.SetDefault(false);
AddAttr<bool>(
"return_inverse",
"If True, also return the indices for where elements"
" in the original input ended up in the returned unique tensor.")
.SetDefault(false);
AddAttr<bool>("return_counts",
"If True, also return the counts for each unique element.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"axis",
"The axis to apply unique. If None, the input will be flattened.")
.SetDefault({});
AddAttr<bool>("is_sorted",
"If True, the unique elements of X are in ascending order."
"Otherwise, the unique elements are not sorted.")
.SetDefault(false);
AddComment(R"DOC(
Return a unique subsequence for 1-D input tensor, and an index tensor pointing to this unique subsequence
1. Return a unique subsequence for 1-D input tensor, and an index tensor
pointing to this unique subsequence when Attr(is_sorted) is false. This
means paddle.unique is called.
2. Returns the unique elements of X in ascending order when Attr(is_sorted)
is true. This means fluid.layers.unique is called.
)DOC");
}
};
@ -65,6 +144,8 @@ class UniqueOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(unique, ops::UniqueOp, ops::UniqueOpMaker);
REGISTER_OP_CPU_KERNEL(unique, ops::UniqueKernel<float>,
ops::UniqueKernel<double>, ops::UniqueKernel<int32_t>,
ops::UniqueKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(
unique, ops::UniqueKernel<paddle::platform::CPUDeviceContext, float>,
ops::UniqueKernel<paddle::platform::CPUDeviceContext, double>,
ops::UniqueKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::UniqueKernel<paddle::platform::CPUDeviceContext, int64_t>);

File diff suppressed because it is too large Load Diff

@ -62,6 +62,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"sync_batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
{"unique", {"Out", "Index", "Indices", "Counts"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are

@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
@ -125,5 +126,164 @@ class TestRandomGPU(TestUniqueOp):
self.check_output_with_place(place, atol=1e-5)
class TestSortedUniqueOp(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64')}
unique, indices, inverse, count = np.unique(
self.inputs['X'],
return_index=True,
return_inverse=True,
return_counts=True,
axis=None)
self.attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
"return_index": True,
"return_inverse": True,
"return_counts": True,
"axis": None,
"is_sorted": True
}
self.outputs = {
'Out': unique,
'Indices': indices,
"Index": inverse,
"Counts": count,
}
class TestUniqueOpAxisNone(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.random((4, 7, 10)).astype('float64')}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
return_inverse=True,
return_counts=True,
axis=None)
self.attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
"return_index": True,
"return_inverse": True,
"return_counts": True,
"axis": None,
"is_sorted": True
}
self.outputs = {
'Out': unique,
'Indices': indices,
"Index": inverse,
"Counts": counts,
}
class TestUniqueOpAxis1(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
return_inverse=True,
return_counts=True,
axis=1)
self.attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
"return_index": True,
"return_inverse": True,
"return_counts": True,
"axis": [1],
"is_sorted": True
}
self.outputs = {
'Out': unique,
'Indices': indices,
"Index": inverse,
"Counts": counts,
}
class TestUniqueAPI(unittest.TestCase):
def test_dygraph_api_out(self):
paddle.disable_static()
x_data = x_data = np.random.randint(0, 10, (120))
x = paddle.to_tensor(x_data)
out = paddle.unique(x)
expected_out = np.unique(x_data)
self.assertTrue((out.numpy() == expected_out).all(), True)
paddle.enable_static()
def test_dygraph_api_attr(self):
paddle.disable_static()
x_data = np.random.random((3, 5, 5)).astype("float32")
x = paddle.to_tensor(x_data)
out, index, inverse, counts = paddle.unique(
x,
return_index=True,
return_inverse=True,
return_counts=True,
axis=0)
np_out, np_index, np_inverse, np_counts = np.unique(
x_data,
return_index=True,
return_inverse=True,
return_counts=True,
axis=0)
self.assertTrue((out.numpy() == np_out).all(), True)
self.assertTrue((index.numpy() == np_index).all(), True)
self.assertTrue((inverse.numpy() == np_inverse).all(), True)
self.assertTrue((counts.numpy() == np_counts).all(), True)
paddle.enable_static()
def test_static_graph(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.data(name='x', shape=[3, 2], dtype='float64')
unique, inverse, counts = paddle.unique(
x, return_inverse=True, return_counts=True, axis=0)
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
x_np = np.array([[1, 2], [3, 4], [1, 2]]).astype('float64')
result = exe.run(feed={"x": x_np},
fetch_list=[unique, inverse, counts])
np_unique, np_inverse, np_counts = np.unique(
x_np, return_inverse=True, return_counts=True, axis=0)
self.assertTrue(np.allclose(result[0], np_unique))
self.assertTrue(np.allclose(result[1], np_inverse))
self.assertTrue(np.allclose(result[2], np_counts))
class TestUniqueError(unittest.TestCase):
def test_input_dtype(self):
def test_x_dtype():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.data(name='x', shape=[10, 10], dtype='float16')
result = paddle.unique(x)
self.assertRaises(TypeError, test_x_dtype)
def test_attr(self):
x = paddle.data(name='x', shape=[10, 10], dtype='float64')
def test_return_index():
result = paddle.unique(x, return_index=0)
self.assertRaises(TypeError, test_return_index)
def test_return_inverse():
result = paddle.unique(x, return_inverse='s')
self.assertRaises(TypeError, test_return_inverse)
def test_return_counts():
result = paddle.unique(x, return_counts=3)
self.assertRaises(TypeError, test_return_counts)
def test_axis():
result = paddle.unique(x, axis='12')
self.assertRaises(TypeError, test_axis)
if __name__ == "__main__":
unittest.main()

@ -27,7 +27,6 @@ from ..fluid.layers import expand_as #DEFINE_ALIAS
from ..fluid.layers import slice #DEFINE_ALIAS
from ..fluid.layers import strided_slice #DEFINE_ALIAS
from ..fluid.layers import transpose #DEFINE_ALIAS
from ..fluid.layers import unique #DEFINE_ALIAS
from ..fluid.layers import unstack #DEFINE_ALIAS
from ..fluid.layers import scatter_nd_add #DEFINE_ALIAS
@ -608,6 +607,126 @@ def squeeze(x, axis=None, name=None):
return layers.squeeze(x, axis, name)
def unique(x,
return_index=False,
return_inverse=False,
return_counts=False,
axis=None,
name=None):
"""
Returns the unique elements of `x` in ascending order.
Args:
x(Tensor): The input tensor, it's data type should be float32, float64, int32, int64.
return_index(bool, optional): If True, also return the indices of the input tensor that
result in the unique Tensor.
return_inverse(bool, optional): If True, also return the indices for where elements in
the original input ended up in the returned unique tensor.
return_counts(bool, optional): If True, also return the counts for each unique element.
axis(int, optional): The axis to apply unique. If None, the input will be flattened.
Default: None.
name(str, optional): Name for the operation. For more information, please refer to
:ref:`api_guide_Name`. Default: None.
Returns:
tuple: (out, indices, inverse, counts). `out` is the unique tensor for `x`. `indices` is \
provided only if `return_index` is True. `inverse` is provided only if `return_inverse` \
is True. `counts` is provided only if `return_counts` is True.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
x_data = np.array([2, 3, 3, 1, 5, 3])
x = paddle.to_tensor(x_data)
unique = paddle.unique(x)
np_unique = unique.numpy() # [1 2 3 5]
_, indices, inverse, counts = paddle.unique(x, return_index=True, return_inverse=True, return_counts=True)
np_indices = indices.numpy() # [3 0 1 4]
np_inverse = inverse.numpy() # [1 2 2 0 3 2]
np_counts = counts.numpy() # [1 1 3 1]
x_data = np.array([[2, 1, 3], [3, 0, 1], [2, 1, 3]])
unique = paddle.unique(x)
np_unique = unique.numpy() # [0 1 2 3]
unique = paddle.unique(x, axis=0)
np_unique = unique.numpy()
# [[2 1 3]
# [3 0 1]]
"""
if axis is None:
axis = []
else:
axis = [axis]
if in_dygraph_mode():
out, inverse, indices, counts = core.ops.unique(
x, 'dtype',
convert_np_dtype_to_dtype_('int32'), 'return_index', return_index,
'return_inverse', return_inverse, 'return_counts', return_counts,
'axis', axis, "is_sorted", True)
outs = [out]
if return_index:
outs.append(indices)
if return_inverse:
outs.append(inverse)
if return_counts:
outs.append(counts)
if len(outs) == 1:
return outs[0]
return tuple(outs)
check_variable_and_dtype(x, "input",
['float32', 'float64', 'int32', 'int64'], 'unique')
check_type(return_index, 'return_index', bool, 'unique')
check_type(return_inverse, 'return_inverse', bool, 'unique')
check_type(return_counts, 'return_counts', bool, 'unique')
if len(axis) != 0:
check_type(axis[0], 'axis', int, 'unique')
helper = LayerHelper('unique', **locals())
attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
"return_index": return_index,
"return_inverse": return_inverse,
"return_counts": return_counts,
"axis": axis,
"is_sorted": True
}
out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
inverse = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True)
outputs = {"Out": out, "Index": inverse}
outs = [out]
if return_index:
indices = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True)
outputs["Indices"] = indices
outs.append(indices)
if return_inverse:
outs.append(inverse)
if return_counts:
counts = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True)
outputs["Counts"] = counts
outs.append(counts)
helper.append_op(
type="unique", inputs={"X": x}, attrs=attrs, outputs=outputs)
if len(outs) == 1:
return outs[0]
return tuple(outs)
def unsqueeze(x, axis, name=None):
"""
:alias_main: paddle.unsqueeze

Loading…
Cancel
Save