Add topk op (#3760)
* init add
* add topk op
* someupdate
* fix style check
* add test py file
* update top k cuda kernel
* follow comments
* remove debug print
* fix casting error
* fix casting error
* fix casting error
* fix rename bug...
* fix travis
parent 2f40da0923
commit 3fbb692d4b
@@ -0,0 +1,67 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/top_k_op.h"

namespace paddle {
namespace operators {

class TopkOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            "Input of TopkOp must be initialized.");
    auto *input = ctx.Input<framework::Tensor>("X");
    const int k = static_cast<int>(ctx.Attr<int>("k"));

    PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
    PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
    PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
                      "input must have >= k columns");

    framework::DDim dims = input->dims();
    dims[dims.size() - 1] = k;
    ctx.Output<Tensor>("Out")->Resize(dims);
    ctx.Output<Tensor>("Indices")->Resize(dims);
  }
};

class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of Topk op");
    AddOutput("Out", "The output tensor of Topk op");
    AddOutput("Indices", "The indices of Topk elements of input");
    AddComment(
        R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].

For matrices, computes the top k entries in each row. )DOC");
    AddAttr<int>("k",
                 "Number of top elements to look for along the last "
                 "dimension (along each row for matrices).")
        .SetDefault(1);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
REGISTER_OP_CPU_KERNEL(top_k,
                       ops::TopkKernel<paddle::platform::CPUPlace, float>);
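A quick note on the semantics documented in TopkOpMaker, with a small numpy sketch that is not part of the commit (the array values and the helper name ref_top_k are made up for illustration): for an input whose last dimension has n entries and an attribute k, "Out" and "Indices" keep every leading dimension and replace the last one with k; each row of "Out" holds that row's k largest values and "Indices" holds their column positions.

import numpy as np

def ref_top_k(x, k):
    # Column indices of the k largest entries per row, largest first.
    idx = np.argsort(-x, axis=-1)[:, :k]
    # Gather the matching values row by row.
    val = x[np.arange(x.shape[0])[:, None], idx]
    return val, idx

x = np.array([[1.0, 5.0, 3.0],
              [9.0, 2.0, 4.0]])
val, idx = ref_top_k(x, 2)
# val -> [[5., 3.], [9., 4.]]   (shape (2, 2): last dim replaced by k)
# idx -> [[1, 2], [0, 2]]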
File diff suppressed because it is too large
@@ -0,0 +1,76 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
class TopkKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Get the top k elements of each row of the input tensor.
    // FIXME: only deals with matrices (2d tensors).
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    // k is determined by Attr.
    const size_t k = static_cast<int>(ctx.Attr<int>("k"));

    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    T* indices_data = indices->mutable_data<T>(ctx.GetPlace());

    auto eg_input = EigenMatrix<T>::From(*input);

    // Reshape the input to a flattened matrix (like flat_inner_dims).
    framework::DDim inputdims = input->dims();
    const size_t row = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const size_t col = inputdims[inputdims.size() - 1];
    Eigen::DSizes<int, 2> flat2dims(row, col);
    // NOTE: the eigen shape doesn't affect the paddle tensor.
    eg_input.reshape(flat2dims);

    for (size_t i = 0; i < row; i++) {
      std::vector<std::pair<T, size_t>> vec;
      for (size_t j = 0; j < col; j++) {
        vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
      }

      std::partial_sort(
          vec.begin(), vec.begin() + k, vec.end(),
          [](const std::pair<T, size_t>& l, const std::pair<T, size_t>& r) {
            return l.first > r.first;
          });
      for (size_t j = 0; j < k; j++) {
        output_data[i * k + j] = vec[j].first;
        indices_data[i * k + j] = vec[j].second;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
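To make the CPU kernel's strategy easier to follow, here is a minimal Python sketch that is not part of the commit (the helper name cpu_topk_reference is made up): like the kernel, it flattens every leading dimension into rows, builds (value, index) pairs per row, orders them by value descending (the role of std::partial_sort with the '>' comparator), and copies the first k values and indices into the outputs. Note that the kernel itself writes the indices into a buffer of the same element type T as the values.

import numpy as np

def cpu_topk_reference(x, k):
    # Flatten every leading dimension into rows; the last dimension stays
    # as columns, mirroring the kernel's flat-2d reshape.
    col = x.shape[-1]
    flat = x.reshape(-1, col)
    out = np.empty((flat.shape[0], k), dtype=x.dtype)
    idx = np.empty((flat.shape[0], k), dtype=np.int64)
    for i, row in enumerate(flat):
        # (value, index) pairs ordered by value, largest first.
        pairs = sorted(((v, j) for j, v in enumerate(row)),
                       key=lambda p: p[0], reverse=True)[:k]
        out[i] = [p[0] for p in pairs]
        idx[i] = [p[1] for p in pairs]
    # The outputs keep the leading dimensions, with the last one replaced by k.
    return out.reshape(x.shape[:-1] + (k,)), idx.reshape(x.shape[:-1] + (k,))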
@@ -0,0 +1,52 @@
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta


class TestTopkOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "top_k"
        k = 1
        input = np.random.random((32, 84)).astype("float32")
        output = np.ndarray((32, k))
        indices = np.ndarray((32, k))

        self.inputs = {'X': input}
        self.attrs = {'k': k}

        for rowid in xrange(32):
            row = input[rowid]
            output[rowid] = np.sort(row)[-k:]
            indices[rowid] = row.argsort()[-k:]

        self.outputs = {'Out': output, 'Indices': indices}


class TestTopkOp3d(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "top_k"
        k = 1
        input = np.random.random((32, 2, 84)).astype("float32")
        input_flat_2d = input.reshape(64, 84)
        output = np.ndarray((64, k))
        indices = np.ndarray((64, k)).astype("int")

        # FIXME: should use 'X': input for a 3d input
        self.inputs = {'X': input_flat_2d}
        self.attrs = {'k': k}

        for rowid in xrange(64):
            row = input_flat_2d[rowid]
            output[rowid] = np.sort(row)[-k:]
            indices[rowid] = row.argsort()[-k:]

        self.outputs = {'Out': output, 'Indices': indices}


if __name__ == '__main__':
    unittest.main()
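One detail worth calling out about the reference values above (the snippet below is not part of the commit): np.sort(row)[-k:] yields the k largest entries in ascending order, whereas the CPU kernel emits them in descending order, so the two orderings only coincide for k = 1, which is the value used in both test cases.

import numpy as np

row = np.array([0.1, 0.9, 0.5, 0.7])
k = 2
ascending_topk = np.sort(row)[-k:]        # [0.7, 0.9]  (the test's reference order)
descending_topk = np.sort(row)[::-1][:k]  # [0.9, 0.7]  (the kernel's order)
# For k == 1 both reduce to the same single value.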