From 4760f2851ef37186c836a1cf46fea87f5a806fb2 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 4 Jun 2018 10:31:32 -0700
Subject: [PATCH 1/8] Add the argsort operator

---
 paddle/fluid/operators/argsort_op.cc          | 83 ++++++++++++++++++
 paddle/fluid/operators/argsort_op.h           | 86 +++++++++++++++++++
 .../fluid/tests/unittests/test_argsort_op.py  | 49 +++++++++++
 3 files changed, 218 insertions(+)
 create mode 100644 paddle/fluid/operators/argsort_op.cc
 create mode 100644 paddle/fluid/operators/argsort_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_argsort_op.py

diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
new file mode 100644
index 0000000000..aead4e2e00
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -0,0 +1,83 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/argsort_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ArgsortOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ArgsortOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ArgsortOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Indices"),
+                   "Output(Indices) of ArgsortOp should not be null.");
+
+    auto in_dims = ctx->GetInputDim("X");
+    int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
+
+    auto num_dims = in_dims.size();
+    PADDLE_ENFORCE(axis < num_dims,
+                   "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
+                   "dimension %d.",
+                   axis, num_dims);
+    PADDLE_ENFORCE(axis >= 0 || axis == -1,
+                   "Attr(axis) %d of ArgsortOp must be nonnegative or equal to "
+                   "-1.",
+                   axis);
+
+    ctx->SetOutputDim("Out", in_dims);
+    ctx->SetOutputDim("Indices", in_dims);
+    ctx->ShareLoD("X", "Out");
+    ctx->ShareLoD("X", "Indices");
+  }
+};
+
+class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) The input of Argsort op.");
+    AddOutput("Out", "(Tensor) The sorted tensor of Argsort op.");
+    AddOutput("Indices",
+              "(Tensor) The indices of a tensor giving the sorted order.");
+    AddComment(R"DOC(
+Argsort operator
+
+Performs sorting on the input tensor along the given axis and outputs two
+tensors, Output(Out) and Output(Indices). They share the same shape
+as Input(X), and Output(Out) represents the sorted tensor while
+Output(Indices) gives the sorted order along the given axis Attr(axis).
+
+)DOC");
+    AddAttr<int>("axis",
+                 "(int, default -1) The axis along which to sort the tensor, "
+                 "default -1, the last dimension.")
+        .SetDefault(-1);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    argsort, ops::ArgsortKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ArgsortKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
new file mode 100644
index 0000000000..a9fe22c4ce
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <iostream>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, int MajorType = Eigen::RowMajor, typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class ArgsortKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    auto* indices = ctx.Output<Tensor>("Indices");
+    int axis = static_cast<int>(ctx.Attr<int>("axis"));
+
+    auto in_dims = input->dims();
+    axis = (axis == -1) ? (in_dims.size() - 1) : axis;
+
+    const T* in_data = input->data<T>();
+    T* out_data = output->mutable_data<T>(ctx.GetPlace());
+    int64_t* idx_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    int64_t part_dims_prod = input->numel() / in_dims[axis];
+    for (int64_t i = 0; i < part_dims_prod; ++i) {
+      int64_t idx = i;
+      std::vector<int64_t> idx_vec(in_dims.size(), 0);
+      for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
+        if (dim != axis) {
+          idx_vec[dim] = idx % in_dims[dim];
+          idx /= in_dims[dim];
+        }
+      }
+      std::vector<std::pair<T, int64_t>> in_vec;
+      std::vector<int64_t> org_index_vec(in_dims[axis], 0);
+      for (int64_t j = 0; j < in_dims[axis]; ++j) {
+        idx_vec[axis] = j;
+        int64_t index = idx_vec[0];
+        for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
+          index = index * in_dims[dim + 1] + idx_vec[dim + 1];
+        }
+        in_vec.push_back(std::pair<T, int64_t>(in_data[index], j));
+        org_index_vec[j] = index;
+      }
+
+      std::sort(
+          in_vec.begin(), in_vec.end(),
+          [](const std::pair<T, int64_t>& v1, const std::pair<T, int64_t>& v2) {
+            return v1.first < v2.first;
+          });
+
+      for (size_t j = 0; j < org_index_vec.size(); ++j) {
+        int64_t index = org_index_vec[j];
+        out_data[index] = in_vec[j].first;
+        idx_data[index] = in_vec[j].second;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py
new file mode 100644
index 0000000000..6995621ba8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestArgsortOp(OpTest):
+    def setUp(self):
+        self.init_axis()
+        x = np.random.random((2, 3, 4, 5)).astype("float32")
+        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
+        self.out = np.sort(x, kind='quicksort', axis=self.axis)
+        self.op_type = "argsort"
+        self.inputs = {'X': x}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {'Indices': self.indices, 'Out': self.out}
+
+    def init_axis(self):
+        self.axis = -1
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestArgsortOpAxis0(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 0
+
+
+class TestArgsortOpAxis1(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 1
+
+
+if __name__ == "__main__":
+    unittest.main()
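
The CPU kernel in this first patch fixes the contract that every later patch
preserves: for each one-dimensional slice taken along Attr(axis), Output(Out)
holds the sorted values and Output(Indices) holds each value's position before
sorting. A standalone NumPy sketch of that contract (illustration only, not
part of the patch; it mirrors what test_argsort_op.py checks):

    import numpy as np

    x = np.random.random((2, 3, 4)).astype("float32")
    axis = -1

    indices = np.argsort(x, kind='quicksort', axis=axis)  # Output(Indices)
    out = np.sort(x, kind='quicksort', axis=axis)         # Output(Out)

    # Gathering the input at `indices` along `axis` reproduces `out`, and
    # both results keep the input's shape.
    assert np.take_along_axis(x, indices, axis=axis).tolist() == out.tolist()
    assert indices.shape == out.shape == x.shape
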
From 2c2120c8a269080a883ad4b1eb8022d71f3d2cf8 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 4 Jun 2018 20:25:25 -0700
Subject: [PATCH 2/8] Remove redundant code

---
 paddle/fluid/operators/argsort_op.h | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
index a9fe22c4ce..51d2b89f94 100644
--- a/paddle/fluid/operators/argsort_op.h
+++ b/paddle/fluid/operators/argsort_op.h
@@ -14,28 +14,20 @@ limitations under the License. */
 
 #pragma once
 #include <algorithm>
-#include <iostream>
 #include <utility>
 #include <vector>
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-
-template <typename T, int MajorType = Eigen::RowMajor, typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
-
-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class ArgsortKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<Tensor>("X");
-    auto* output = ctx.Output<Tensor>("Out");
-    auto* indices = ctx.Output<Tensor>("Indices");
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    auto* indices = ctx.Output<framework::Tensor>("Indices");
     int axis = static_cast<int>(ctx.Attr<int>("axis"));
From 6ee22c4f71173ada588772ee2b3e2828320d9e17 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Tue, 12 Jun 2018 01:43:08 -0700
Subject: [PATCH 3/8] Add gpu kernel for argsort op

---
 paddle/fluid/operators/argsort_op.cu | 140 +++++++++++++++++++++++++++
 1 file changed, 140 insertions(+)
 create mode 100644 paddle/fluid/operators/argsort_op.cu

diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
new file mode 100644
index 0000000000..d1fbd28e1b
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -0,0 +1,140 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <thrust/device_vector.h>
+#include <thrust/sort.h>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/argsort_op.h"
+#include "paddle/fluid/platform/assert.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T>
+__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
+                              T* med_out) {
+  int index = threadIdx.x + blockDim.x * blockIdx.x;
+  if (index < n) {
+    med_out[trg_idx[index]] = in[index];
+  }
+}
+
+template <typename T>
+__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
+                     int64_t* med_ids) {
+  int index = threadIdx.x + blockDim.x * blockIdx.x;
+  if (index < groups) {
+    thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
+                        med_out + axis_dim * (1 + index),
+                        med_ids + index * axis_dim);
+  }
+}
+
+template <typename T>
+__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
+                                   const int64_t* trg_idx, int64_t n, T* out,
+                                   int64_t* indices) {
+  int index = threadIdx.x + blockDim.x * blockIdx.x;
+  if (index < n) {
+    out[index] = med_out[trg_idx[index]];
+    indices[index] = med_ids[trg_idx[index]];
+  }
+}
+
+template <typename T>
+class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    auto* indices = ctx.Output<Tensor>("Indices");
+    int axis = ctx.Attr<int>("axis");
+
+    auto in_dims = input->dims();
+    axis = (axis == -1) ? (in_dims.size() - 1) : axis;
+
+    const T* in_data = input->data<T>();
+    T* out_data = output->mutable_data<T>(ctx.GetPlace());
+    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    int64_t numel = input->numel();
+    int64_t groups = numel / in_dims[axis];
+
+    // Mediate tensor for sorting
+    Tensor mediate_output;
+    T* med_out_data =
+        mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
+
+    // The target index of each element in mediate tensor
+    std::vector<int64_t> target_idx(numel, 0);
+    // To record the index along the given axis for the data in mediate tensor
+    std::vector<int64_t> mediate_indices(numel, 0);
+    std::vector<int64_t> in_dims_out_axis = vectorize(in_dims);
+    in_dims_out_axis.erase(in_dims_out_axis.begin() + axis);
+    for (int64_t index = 0; index < numel; ++index) {
+      int64_t tmp = index;
+      int64_t pos_in_axis = 0;
+      std::vector<int64_t> shape;
+      for (int64_t j = in_dims.size() - 1; j >= 0; --j) {
+        if (j != axis) {
+          shape.push_back(tmp % in_dims[j]);
+        } else {
+          pos_in_axis = tmp % in_dims[j];
+        }
+        tmp /= in_dims[j];
+      }
+      std::reverse(shape.begin(), shape.end());
+      int64_t group = (shape.size() > 0) ? shape[0] : 0;
+      for (size_t j = 0; j < shape.size() - 1; ++j) {
+        group = group * in_dims_out_axis[j + 1] + shape[j + 1];
+      }
+
+      target_idx[index] = group * in_dims[axis] + pos_in_axis;
+      mediate_indices[target_idx[index]] = pos_in_axis;
+    }
+
+    thrust::device_vector<int64_t> med_ids_dev(mediate_indices.begin(),
+                                               mediate_indices.end());
+    int64_t* med_ids_data = thrust::raw_pointer_cast(med_ids_dev.data());
+    thrust::device_vector<int64_t> trg_idx_dev(target_idx.begin(),
+                                               target_idx.end());
+    int64_t* trg_idx = thrust::raw_pointer_cast(trg_idx_dev.data());
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    auto num_threads = PADDLE_CUDA_NUM_THREADS;
+
+    PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
+        in_data, trg_idx, numel, med_out_data);
+
+    Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
+        in_dims[axis], groups, med_out_data, med_ids_data);
+
+    PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
+                         stream>>>(med_out_data, med_ids_data, trg_idx, numel,
+                                   out_data, ids_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_CUDA_KERNEL(argsort,
+                        paddle::operators::ArgsortOpCUDAKernel<float>,
+                        paddle::operators::ArgsortOpCUDAKernel<double>);
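
The GPU kernel works in three stages: PermuteInData gathers every axis-slice
into a contiguous segment of the "mediate" buffer, Sort runs one
thrust::sort_by_key per segment with the positions riding along as values, and
PermuteMediateData scatters both results back into the original layout. A
NumPy sketch of the same pipeline (illustration only, not part of the patch;
moveaxis/reshape stands in for the precomputed trg_idx permutation):

    import numpy as np

    x = np.random.random((2, 3, 4)).astype("float32")
    axis = 1
    axis_dim = x.shape[axis]
    rest = x.shape[:axis] + x.shape[axis + 1:]  # dims with `axis` removed

    # Stage 1 (PermuteInData): each axis-slice becomes one contiguous row.
    med = np.moveaxis(x, axis, -1).reshape(-1, axis_dim)
    med_ids = np.tile(np.arange(axis_dim), (med.shape[0], 1))

    # Stage 2 (Sort): sort every row independently; the positions follow the
    # keys, as med_ids does for thrust::sort_by_key.
    order = np.argsort(med, axis=1)
    med_sorted = np.take_along_axis(med, order, axis=1)
    ids_sorted = np.take_along_axis(med_ids, order, axis=1)

    # Stage 3 (PermuteMediateData): undo the permutation.
    def scatter_back(m):
        return np.moveaxis(m.reshape(rest + (axis_dim,)), -1, axis)

    assert np.allclose(scatter_back(med_sorted), np.sort(x, axis=axis))
    assert (scatter_back(ids_sorted) == np.argsort(x, axis=axis)).all()

Note that this first version still computes the target-index permutation on
the host and copies it to the device; the next patch moves that work onto the
GPU.
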
From 42645ff779dcaea3f68ccdc2e183199857c2e18e Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Tue, 12 Jun 2018 05:52:47 -0700
Subject: [PATCH 4/8] Compute target index on gpu

---
 paddle/fluid/operators/argsort_op.cc |  2 +-
 paddle/fluid/operators/argsort_op.cu | 89 ++++++++++++++++------------
 2 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index aead4e2e00..2943d409a2 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -30,7 +30,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
                    "Output(Indices) of ArgsortOp should not be null.");
 
     auto in_dims = ctx->GetInputDim("X");
-    int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
+    int axis = ctx->Attrs().Get<int>("axis");
 
     auto num_dims = in_dims.size();
     PADDLE_ENFORCE(axis < num_dims,
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index d1fbd28e1b..eac18ea3a0 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -26,6 +26,42 @@ namespace operators {
 using Tensor = framework::Tensor;
 using platform::PADDLE_CUDA_NUM_THREADS;
 
+__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
+                                 int axis, int64_t n, int64_t* trg_idx,
+                                 int64_t* med_ids) {
+  int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
+  if (index < n) {
+    int64_t* shape_out_axis = new int64_t[dims_size - 1];
+    int64_t* dims_out_axis = new int64_t[dims_size - 1];
+    int64_t tmp = index;
+    int64_t pos_in_axis = 0;
+    int64_t i = dims_size - 2;
+    int64_t dim_axis = 0;
+    for (int64_t j = dims_size - 1; j >= 0; --j) {
+      int64_t dim = in_dims[j];
+      if (j != axis) {
+        shape_out_axis[i] = tmp % dim;
+        dims_out_axis[i] = dim;
+        i--;
+      } else {
+        dim_axis = dim;
+        pos_in_axis = tmp % dim_axis;
+      }
+      tmp /= dim;
+    }
+    int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
+    for (int64_t j = 0; j < dims_size - 2; ++j) {
+      group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
+    }
+
+    int64_t target_idx = group * dim_axis + pos_in_axis;
+    trg_idx[index] = target_idx;
+    med_ids[target_idx] = pos_in_axis;
+    delete[] shape_out_axis;
+    delete[] dims_out_axis;
+  }
+}
+
 template <typename T>
 __global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
                               T* med_out) {
@@ -76,50 +112,27 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     int64_t numel = input->numel();
     int64_t groups = numel / in_dims[axis];
 
-    // Mediate tensor for sorting
-    Tensor mediate_output;
+    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
+    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
+                                               in_dims_vec.end());
+    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
+    // Mediate tensor for sorting data and indices
+    Tensor mediate_output, mediate_indices;
     T* med_out_data =
         mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
-
-    // The target index of each element in mediate tensor
-    std::vector<int64_t> target_idx(numel, 0);
-    // To record the index along the given axis for the data in mediate tensor
-    std::vector<int64_t> mediate_indices(numel, 0);
-    std::vector<int64_t> in_dims_out_axis = vectorize(in_dims);
-    in_dims_out_axis.erase(in_dims_out_axis.begin() + axis);
-    for (int64_t index = 0; index < numel; ++index) {
-      int64_t tmp = index;
-      int64_t pos_in_axis = 0;
-      std::vector<int64_t> shape;
-      for (int64_t j = in_dims.size() - 1; j >= 0; --j) {
-        if (j != axis) {
-          shape.push_back(tmp % in_dims[j]);
-        } else {
-          pos_in_axis = tmp % in_dims[j];
-        }
-        tmp /= in_dims[j];
-      }
-      std::reverse(shape.begin(), shape.end());
-      int64_t group = (shape.size() > 0) ? shape[0] : 0;
-      for (size_t j = 0; j < shape.size() - 1; ++j) {
-        group = group * in_dims_out_axis[j + 1] + shape[j + 1];
-      }
-
-      target_idx[index] = group * in_dims[axis] + pos_in_axis;
-      mediate_indices[target_idx[index]] = pos_in_axis;
-    }
-
-    thrust::device_vector<int64_t> med_ids_dev(mediate_indices.begin(),
-                                               mediate_indices.end());
-    int64_t* med_ids_data = thrust::raw_pointer_cast(med_ids_dev.data());
-    thrust::device_vector<int64_t> trg_idx_dev(target_idx.begin(),
-                                               target_idx.end());
-    int64_t* trg_idx = thrust::raw_pointer_cast(trg_idx_dev.data());
+    int64_t* med_ids_data =
+        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
+    // Target index of each element along the given axis in the mediate tensors
+    Tensor trg_idx_t;
+    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
 
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
-    auto num_threads = PADDLE_CUDA_NUM_THREADS;
+    int num_threads = PADDLE_CUDA_NUM_THREADS;
+
+    ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
+        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
 
     PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
         in_data, trg_idx, numel, med_out_data);
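
ComputeTargetIdx moves the index bookkeeping onto the device: each thread
decomposes its flattened row-major offset into per-dimension coordinates,
folds the coordinates outside `axis` into a group id, and writes its element
to slot group * dim_axis + pos_in_axis of the mediate buffer. A pure-Python
sketch of that per-thread computation (illustration only; the helper name is
hypothetical):

    # Hypothetical helper mirroring one ComputeTargetIdx thread.
    def target_index(index, in_dims, axis):
        coords = [0] * len(in_dims)
        for j in range(len(in_dims) - 1, -1, -1):  # peel digits right-to-left
            coords[j] = index % in_dims[j]
            index //= in_dims[j]
        pos_in_axis = coords[axis]
        group = 0
        for j, dim in enumerate(in_dims):          # fold coords outside `axis`
            if j != axis:
                group = group * dim + coords[j]
        return group * in_dims[axis] + pos_in_axis

    # Offsets 0..23 of a (2, 3, 4) tensor with axis=1 land in 24 distinct
    # slots, i.e. the mapping is a permutation of the flat indices.
    dims, axis = [2, 3, 4], 1
    assert sorted(target_index(i, dims, axis) for i in range(24)) == list(range(24))
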
From 94e72ea6e7eba2f89533225f57626cfed93c0155 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Tue, 12 Jun 2018 06:31:01 -0700
Subject: [PATCH 5/8] Support more negative axes in argsort_op

---
 paddle/fluid/operators/argsort_op.cc          | 20 +++++++++++--------
 paddle/fluid/operators/argsort_op.cu          |  2 +-
 paddle/fluid/operators/argsort_op.h           |  2 +-
 .../fluid/tests/unittests/test_argsort_op.py  |  7 +++++++
 4 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index 2943d409a2..8a44fd12ce 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel {
                    "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
                    "dimension %d.",
                    axis, num_dims);
-    PADDLE_ENFORCE(axis >= 0 || axis == -1,
-                   "Attr(axis) %d of ArgsortOp must be nonnegative or equal to "
-                   "-1.",
-                   axis);
+    PADDLE_ENFORCE(in_dims.size() + axis >= 0,
+                   "Attr(axis) %d of ArgsortOp plus the number of Input(X)'s "
+                   "dimensions %d must be nonnegative.",
+                   axis, in_dims.size());
 
     ctx->SetOutputDim("Out", in_dims);
     ctx->SetOutputDim("Indices", in_dims);
@@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor) The input of Argsort op.");
-    AddOutput("Out", "(Tensor) The sorted tensor of Argsort op.");
+    AddOutput("Out",
+              "(Tensor) The sorted tensor of Argsort op, with the same "
+              "shape as Input(X).");
     AddOutput("Indices",
-              "(Tensor) The indices of a tensor giving the sorted order.");
+              "(Tensor) The indices of a tensor giving the sorted order, with "
+              "the same shape as Input(X).");
     AddComment(R"DOC(
 Argsort operator
 
@@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
 
 )DOC");
     AddAttr<int>("axis",
-                 "(int, default -1) The axis along which to sort the tensor, "
-                 "default -1, the last dimension.")
+                 "(int, default -1) The axis along which to sort the tensor. "
+                 "When axis < 0, the actual axis will be the |axis|'th "
+                 "dimension counting backwards. Default -1, the last dimension.")
         .SetDefault(-1);
   }
 };
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index eac18ea3a0..55ad4ce340 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -103,7 +103,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     int axis = ctx.Attr<int>("axis");
 
     auto in_dims = input->dims();
-    axis = (axis == -1) ? (in_dims.size() - 1) : axis;
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
     const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(ctx.GetPlace());
diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
index 51d2b89f94..e13745c494 100644
--- a/paddle/fluid/operators/argsort_op.h
+++ b/paddle/fluid/operators/argsort_op.h
@@ -31,7 +31,7 @@ class ArgsortKernel : public framework::OpKernel<T> {
     int axis = static_cast<int>(ctx.Attr<int>("axis"));
 
     auto in_dims = input->dims();
-    axis = (axis == -1) ? (in_dims.size() - 1) : axis;
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
     const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(ctx.GetPlace());
diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py
index 6995621ba8..1d0aa82a6b 100644
--- a/python/paddle/fluid/tests/unittests/test_argsort_op.py
+++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py
@@ -21,6 +21,8 @@ class TestArgsortOp(OpTest):
     def setUp(self):
         self.init_axis()
         x = np.random.random((2, 3, 4, 5)).astype("float32")
+        if self.axis < 0:
+            self.axis = self.axis + len(x.shape)
         self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
         self.out = np.sort(x, kind='quicksort', axis=self.axis)
         self.op_type = "argsort"
@@ -45,5 +47,10 @@ class TestArgsortOpAxis1(TestArgsortOp):
         self.axis = 1
 
 
+class TestArgsortOpAxisNeg2(TestArgsortOp):
+    def init_axis(self):
+        self.axis = -2
+
+
 if __name__ == "__main__":
     unittest.main()
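
With this patch, every axis in [-rank, rank) is accepted: a negative axis is
normalized by adding the input's rank, matching NumPy's convention, and the
new TestArgsortOpAxisNeg2 case exercises that path. A quick illustration of
the normalization (not from the patch):

    rank = 4  # e.g. the (2, 3, 4, 5) input used by the tests
    for axis in (-4, -2, -1, 0, 3):
        actual = axis + rank if axis < 0 else axis
        assert 0 <= actual < rank
        print(axis, "->", actual)  # -4 -> 0, -2 -> 2, -1 -> 3, 0 -> 0, 3 -> 3
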
From 98460c009eb6a18339097b8ef9be43a216ce1e5f Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Tue, 12 Jun 2018 09:31:10 -0700
Subject: [PATCH 6/8] Simplify the computation in cpu

---
 paddle/fluid/operators/argsort_op.h | 51 +++++++++++++++--------------
 1 file changed, 27 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
index e13745c494..7e9112cfb7 100644
--- a/paddle/fluid/operators/argsort_op.h
+++ b/paddle/fluid/operators/argsort_op.h
@@ -28,47 +28,50 @@ class ArgsortKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<framework::Tensor>("X");
     auto* output = ctx.Output<framework::Tensor>("Out");
     auto* indices = ctx.Output<framework::Tensor>("Indices");
-    int axis = static_cast<int>(ctx.Attr<int>("axis"));
+    int axis = ctx.Attr<int>("axis");
 
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
     const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* idx_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    int64_t part_dims_prod = input->numel() / in_dims[axis];
-    for (int64_t i = 0; i < part_dims_prod; ++i) {
+    int64_t groups = input->numel() / in_dims[axis];
+    int64_t stride = (axis == in_dims.size() - 1)
+                         ? 1
+                         : framework::product(framework::slice_ddim(
+                               in_dims, axis + 1, in_dims.size()));
+
+    for (int64_t i = 0; i < groups; ++i) {
       int64_t idx = i;
-      std::vector<int64_t> idx_vec(in_dims.size(), 0);
+      std::vector<int64_t> shape_vec(in_dims.size(), 0);
       for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
         if (dim != axis) {
-          idx_vec[dim] = idx % in_dims[dim];
+          shape_vec[dim] = idx % in_dims[dim];
           idx /= in_dims[dim];
         }
       }
-      std::vector<std::pair<T, int64_t>> in_vec;
-      std::vector<int64_t> org_index_vec(in_dims[axis], 0);
-      for (int64_t j = 0; j < in_dims[axis]; ++j) {
-        idx_vec[axis] = j;
-        int64_t index = idx_vec[0];
-        for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
-          index = index * in_dims[dim + 1] + idx_vec[dim + 1];
-        }
-        in_vec.push_back(std::pair<T, int64_t>(in_data[index], j));
-        org_index_vec[j] = index;
+
+      int64_t start_index = shape_vec[0];
+      for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
+        start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1];
+      }
+
+      std::vector<int64_t> org_index_vec(in_dims[axis], start_index);
+      for (int64_t j = 1; j < in_dims[axis]; ++j) {
+        org_index_vec[j] += j * stride;
       }
 
-      std::sort(
-          in_vec.begin(), in_vec.end(),
-          [](const std::pair<T, int64_t>& v1, const std::pair<T, int64_t>& v2) {
-            return v1.first < v2.first;
-          });
+      std::sort(org_index_vec.begin(), org_index_vec.end(),
+                [in_data](const int64_t v1, const int64_t v2) {
+                  return in_data[v1] < in_data[v2];
+                });
 
       for (size_t j = 0; j < org_index_vec.size(); ++j) {
-        int64_t index = org_index_vec[j];
-        out_data[index] = in_vec[j].first;
-        idx_data[index] = in_vec[j].second;
+        int64_t index = start_index + j * stride;
+        out_data[index] = in_data[org_index_vec[j]];
+        ids_data[index] = (org_index_vec[j] - start_index) / stride;
       }
     }
   }
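
The simplification rests on a row-major layout fact: along `axis`, neighboring
elements are a fixed number of flat positions apart, namely the product of the
dimensions after `axis`. A slice can therefore be enumerated as
start_index + j * stride and sorted by comparing gathered values directly,
instead of rebuilding full coordinates and pairing each value with its
position. A NumPy check of that stride arithmetic (illustration only, not part
of the patch):

    import numpy as np

    x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    axis = 1
    stride = int(np.prod(x.shape[axis + 1:]))  # 4 here; 1 when axis is last
    flat = x.ravel()

    # start_index is the flat offset of the slice's first element (j == 0).
    start_index = np.ravel_multi_index((1, 0, 2), x.shape)
    picked = flat[[start_index + j * stride for j in range(x.shape[axis])]]
    assert (picked == x[1, :, 2]).all()
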
From 92cfa2be3a29cfdd8bf8ffc7fec76221ba761657 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Sun, 17 Jun 2018 01:44:36 -0700
Subject: [PATCH 7/8] Avoid using dynamic array in cuda kernel

---
 paddle/fluid/operators/argsort_op.cc | 4 ++--
 paddle/fluid/operators/argsort_op.cu | 7 +++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index 8a44fd12ce..ca9a884b98 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -38,8 +38,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
                    "dimension %d.",
                    axis, num_dims);
     PADDLE_ENFORCE(in_dims.size() + axis >= 0,
-                   "Attr(axis) %d of ArgsortOp plus the number of Input(X)'s "
-                   "dimensions %d must be nonnegative.",
+                   "Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) "
+                   "must be nonnegative.",
                    axis, in_dims.size());
 
     ctx->SetOutputDim("Out", in_dims);
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index 55ad4ce340..fc64e51b34 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -31,8 +31,9 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
                                  int64_t* med_ids) {
   int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
   if (index < n) {
-    int64_t* shape_out_axis = new int64_t[dims_size - 1];
-    int64_t* dims_out_axis = new int64_t[dims_size - 1];
+    const int max_rank = 9;  // Max rank of a tensor allowed in Fluid
+    int64_t shape_out_axis[max_rank - 1] = {0};
+    int64_t dims_out_axis[max_rank - 1] = {0};
     int64_t tmp = index;
     int64_t pos_in_axis = 0;
     int64_t i = dims_size - 2;
@@ -57,8 +58,6 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
     int64_t target_idx = group * dim_axis + pos_in_axis;
     trg_idx[index] = target_idx;
     med_ids[target_idx] = pos_in_axis;
-    delete[] shape_out_axis;
-    delete[] dims_out_axis;
   }
 }
 
From a523b6f49f4048d8c32c4d6c53dc22fdcebfe2b0 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 18 Jun 2018 02:30:24 -0700
Subject: [PATCH 8/8] Add python api for argsort_op

---
 doc/fluid/api/layers.rst             |  6 ++++
 paddle/fluid/operators/argsort_op.cc | 12 +++----
 python/paddle/fluid/layers/tensor.py | 51 ++++++++++++++++++++++++++++
 3 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst
index 1f8f636040..4157faae4c 100644
--- a/doc/fluid/api/layers.rst
+++ b/doc/fluid/api/layers.rst
@@ -1105,6 +1105,12 @@ argmax
 .. autofunction:: paddle.fluid.layers.argmax
     :noindex:
 
+argsort
+-------
+
+.. autofunction:: paddle.fluid.layers.argsort
+    :noindex:
+
 ones
 ----
 
diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index ca9a884b98..a2f5a25457 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -34,13 +34,13 @@ class ArgsortOp : public framework::OperatorWithKernel {
 
     auto num_dims = in_dims.size();
     PADDLE_ENFORCE(axis < num_dims,
-                   "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
-                   "dimension %d.",
+                   "Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s "
+                   "rank %d.",
+                   axis, num_dims);
+    PADDLE_ENFORCE(axis >= -num_dims,
+                   "Attr(axis) %d of ArgsortOp must be no less than "
+                   "-rank(Input(X)) (%d).",
                    axis, num_dims);
-    PADDLE_ENFORCE(in_dims.size() + axis >= 0,
-                   "Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) "
-                   "must be nonnegative.",
-                   axis, in_dims.size());
 
     ctx->SetOutputDim("Out", in_dims);
     ctx->SetOutputDim("Indices", in_dims);
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 149e77b524..656bd5bb1d 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -33,6 +33,7 @@ __all__ = [
     'fill_constant',
     'argmin',
     'argmax',
+    'argsort',
     'ones',
     'zeros',
     'reverse',
@@ -438,6 +439,56 @@ def argmax(x, axis=0):
     return out
 
 
+def argsort(input, axis=-1):
+    """
+    Performs sorting on the input Variable along the given axis, and outputs
+    a sorted data Variable and its corresponding index Variable with the same
+    shape as :attr:`input`.
+
+    .. code-block:: text
+
+        For example, if the given axis is -1 and the input Variable is
+
+            input = [[0.15849551, 0.45865775, 0.8563702 ],
+                     [0.12070083, 0.28766365, 0.18776911]],
+
+        after argsort, the sorted Variable becomes
+
+            out = [[0.15849551, 0.45865775, 0.8563702 ],
+                   [0.12070083, 0.18776911, 0.28766365]],
+
+        and the sorted indices along the given axis turn out to be
+
+            indices = [[0, 1, 2],
+                       [0, 2, 1]]
+
+    Args:
+        input(Variable): The input Variable for sorting.
+        axis(int): The axis along which to sort the input Variable. When
+                   :attr:`axis` < 0, the actual axis will be :attr:`axis` +
+                   rank(:attr:`input`). Default -1, the last dimension.
+
+    Returns:
+        tuple: A tuple of the sorted data Variable and the sorted indices.
+
+    Examples:
+        .. code-block:: python
+
+            input = fluid.layers.data(name="input", shape=[2, 3])
+            out, indices = fluid.layers.argsort(input, axis=0)
+    """
+    helper = LayerHelper("argsort", **locals())
+    out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True)
+    ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True)
+    helper.append_op(
+        type='argsort',
+        inputs={'X': input},
+        outputs={'Out': out,
+                 'Indices': ids},
+        attrs={'axis': axis})
+    return out, ids
+
+
 def ones(shape, dtype, force_cpu=False):
     """
     **ones**