add the top_k_v2 op for the PaddlePaddle API 2.0
parent
f82384113b
commit
286eca2d9e
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,176 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/top_k_v2_op.h"

#include <memory>

namespace paddle {
namespace operators {

class TopkV2Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of TopkV2Op should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of TopkV2Op should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Indices"), true,
                      platform::errors::InvalidArgument(
                          "Output(Indices) of TopkV2Op should not be null."));

    auto input_dims = ctx->GetInputDim("X");
    const int dim_size = input_dims.size();
    const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
    int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
    PADDLE_ENFORCE_EQ(
        (axis < dim_size) && (axis >= -dim_size), true,
        platform::errors::InvalidArgument(
            "the axis of topk must be in [-%d, %d), but received %d.",
            dim_size, dim_size, axis));

    if (axis < 0) axis += dim_size;

    PADDLE_ENFORCE_GE(k, 1,
                      platform::errors::InvalidArgument(
                          "the attribute k of topk must be >= 1, "
                          "but received %d.",
                          k));
    PADDLE_ENFORCE_GE(input_dims.size(), 1,
                      platform::errors::InvalidArgument(
                          "the input of topk must have at least 1 dimension."));

    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_GE(
          input_dims[axis], k,
          platform::errors::InvalidArgument(
              "the input of topk must have at least %d elements along "
              "axis %d.",
              k, axis));
    }

    framework::DDim dims = input_dims;
    dims[axis] = k;
    ctx->SetOutputDim("Out", dims);
    ctx->SetOutputDim("Indices", dims);
    ctx->ShareLoD("X", "Out");
    ctx->ShareLoD("X", "Indices");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context(), layout_, library_);
  }
};

class TopkV2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The input of Topk op");
    AddInput("K",
             "(Tensor) Number of top elements to look for along "
             "the last dimension (along each row for matrices).")
        .AsDispensable();
    AddOutput("Out", "(Tensor) The output tensor of Topk op");
    AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
    AddComment(R"DOC(
Top K operator.

If the input is a vector (1-d tensor), this operator finds the k largest
entries in the vector and outputs their values and indices as vectors.
Thus values[j] is the j-th largest entry in the input, and its index is
indices[j].

For matrices, this operator computes the top k entries in each row. )DOC");
    AddAttr<int>("k",
                 "(int, default 1) Number of top elements to look for "
                 "along the given axis.")
        .SetDefault(1);
    AddAttr<int>("axis",
                 "(int, default -1) The axis along which to sort and take "
                 "the k values and indices; defaults to the last axis.")
        .SetDefault(-1);
    AddAttr<bool>("largest",
                  "(bool, default true) Whether to return the largest "
                  "elements (true) or the smallest (false).")
        .SetDefault(true);
    AddAttr<bool>("sorted",
                  "(bool, default true) Whether to return the elements "
                  "in sorted order.")
        .SetDefault(true);
  }
};

class TopkV2OpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::InvalidArgument("Input(X) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true,
                      platform::errors::InvalidArgument(
                          "Input(Indices) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      platform::errors::InvalidArgument(
                          "Grad Input(Out) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                      platform::errors::InvalidArgument(
                          "Grad Output(X) should not be null."));

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

template <typename T>
class TopkV2GradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("top_k_v2_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetInput("X", this->Input("X"));
    op->SetInput("Indices", this->Output("Indices"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(top_k_v2, ops::TopkV2Op, ops::TopkV2OpMaker,
                  ops::TopkV2GradOpMaker<paddle::framework::OpDesc>,
                  ops::TopkV2GradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad);

REGISTER_OP_CPU_KERNEL(top_k_v2,
                       ops::TopkV2Kernel<paddle::platform::CPUPlace, float>,
                       ops::TopkV2Kernel<paddle::platform::CPUPlace, double>,
                       ops::TopkV2Kernel<paddle::platform::CPUPlace, int32_t>,
                       ops::TopkV2Kernel<paddle::platform::CPUPlace, int64_t>)

REGISTER_OP_CPU_KERNEL(
    top_k_v2_grad, ops::TopkV2GradKernel<paddle::platform::CPUPlace, float>,
    ops::TopkV2GradKernel<paddle::platform::CPUPlace, double>,
    ops::TopkV2GradKernel<paddle::platform::CPUPlace, int32_t>,
    ops::TopkV2GradKernel<paddle::platform::CPUPlace, int64_t>)
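
For context, a minimal sketch of how this op surfaces through the Python 2.0 API, using the paddle.topk wrapper exercised by the test below (dygraph mode; the printed values are worked out by hand, not captured from a run):

import paddle

paddle.disable_static()
x = paddle.to_tensor([[1.0, 4.0, 5.0, 7.0],
                      [2.0, 6.0, 2.0, 5.0]])
# k, axis, largest, and sorted map directly onto the attributes
# declared in TopkV2OpMaker above.
values, indices = paddle.topk(x, k=2, axis=1, largest=True, sorted=True)
print(values)   # [[7., 5.], [6., 5.]]
print(indices)  # [[3, 2], [1, 3]]
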
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,244 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core


def numpy_topk(x, k=1, axis=-1, largest=True):
    # NumPy reference implementation: sort along `axis` (descending when
    # largest=True) and keep the first k values and their indices.
    if axis < 0:
        axis = len(x.shape) + axis
    if largest:
        indices = np.argsort(-x, axis=axis)
        value = -np.sort(-x, axis=axis)
    else:
        indices = np.argsort(x, axis=axis)
        value = np.sort(x, axis=axis)
    indices = indices.take(indices=range(0, k), axis=axis)
    value = value.take(indices=range(0, k), axis=axis)
    return value, indices
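

# Worked example of the reference semantics above (hand-checked; shown
# here only as documentation):
#   x = np.array([[1, 3, 2],
#                 [4, 0, 5]])
#   numpy_topk(x, k=2, axis=1, largest=True)
#   returns values  [[3, 2], [5, 4]]
#       and indices [[1, 2], [2, 0]]
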
class TestTopkOp(OpTest):
    def init_args(self):
        self.k = 3
        self.axis = 1
        self.largest = True

    def setUp(self):
        self.op_type = "top_k_v2"
        self.dtype = np.float64
        self.input_data = np.random.rand(10, 20)
        self.init_args()
        self.inputs = {'X': self.input_data}
        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
        output, indices = numpy_topk(
            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
        self.outputs = {'Out': output, 'Indices': indices}

    def test_check_output(self):
        paddle.enable_static()
        self.check_output()

    def test_check_grad(self):
        paddle.enable_static()
        self.check_grad(set(['X']), 'Out')
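

# Note: OpTest drives the "top_k_v2" C++ kernel registered above:
# check_output() runs the op with self.inputs/self.attrs and compares the
# result against self.outputs, while check_grad() numerically verifies the
# gradient of 'Out' with respect to 'X'.
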
class TestTopOp1(TestTopkOp):
    def init_args(self):
        self.k = 3
        self.axis = 0
        self.largest = True


class TestTopOp2(TestTopkOp):
    def init_args(self):
        self.k = 3
        self.axis = 0
        self.largest = False


class TestTopOp3(TestTopkOp):
    def init_args(self):
        self.k = 4
        self.axis = 0
        self.largest = False


class TestTopOp4(TestTopkOp):
    def init_args(self):
        self.k = 4
        self.axis = 0
        self.largest = False


class TestTopkOp5(TestTopkOp):
    def init_args(self):
        self.k = 3
        self.axis = 1
        self.largest = True

    def setUp(self):
        self.op_type = "top_k_v2"
        self.dtype = np.float64
        self.input_data = np.random.rand(10, 10, 5)
        self.init_args()
        self.inputs = {'X': self.input_data}
        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
        output, indices = numpy_topk(
            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
        self.outputs = {'Out': output, 'Indices': indices}


class TestTopkOp6(TestTopkOp):
    def init_args(self):
        self.k = 3
        self.axis = 1
        self.largest = True

    def setUp(self):
        self.op_type = "top_k_v2"
        self.dtype = np.float64
        self.input_data = np.random.rand(10, 10, 5)
        self.init_args()
        self.inputs = {'X': self.input_data}
        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
        output, indices = numpy_topk(
            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
        self.outputs = {'Out': output, 'Indices': indices}
class TestTopKAPI(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
        self.input_data = np.random.rand(6, 7, 8)
        self.large_input_data = np.random.rand(2, 1030)

    def run_dygraph(self, place):
        paddle.disable_static(place)
        input_tensor = paddle.to_tensor(self.input_data)
        large_input_tensor = paddle.to_tensor(self.large_input_data)
        # case 1: default axis
        paddle_result = paddle.topk(input_tensor, k=2)
        numpy_result = numpy_topk(self.input_data, k=2)
        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 2: explicit axis
        paddle_result = paddle.topk(input_tensor, k=2, axis=1)
        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 3: k given as a tensor
        k_tensor = paddle.to_tensor(np.array([2]))
        paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1)
        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 4: smallest elements with largest=False
        paddle_result = paddle.topk(input_tensor, k=2, axis=1, largest=False)
        numpy_result = numpy_topk(self.input_data, k=2, axis=1, largest=False)
        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 5: negative axis
        paddle_result = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
        numpy_result = numpy_topk(self.input_data, k=2, axis=-1, largest=False)
        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 6: a wide input to exercise the partial-sort path
        paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1)
        numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 7: sorted=False; re-sort the returned values before comparing
        paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
        sort_paddle = numpy_topk(
            np.array(paddle_result[0].numpy()), axis=1, k=2)
        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
        self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
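
    # The static-graph variant below builds the same seven cases as
    # run_dygraph, feeds NumPy arrays through a paddle.static.Executor, and
    # fetches every value/index pair in a single run; k is fed as an int32
    # tensor to exercise the dispensable "K" input of the op.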
    def run_static(self, place):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            input_tensor = paddle.static.data(
                name="x", shape=[6, 7, 8], dtype="float64")
            large_input_tensor = paddle.static.data(
                name="large_x", shape=[2, 1030], dtype="float64")
            k_tensor = paddle.static.data(name="k", shape=[1], dtype="int32")
            result1 = paddle.topk(input_tensor, k=2)
            result2 = paddle.topk(input_tensor, k=2, axis=-1)
            result3 = paddle.topk(input_tensor, k=k_tensor, axis=1)
            result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False)
            result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
            result6 = paddle.topk(large_input_tensor, k=1, axis=-1)
            result7 = paddle.topk(input_tensor, k=2, axis=1, sorted=False)

            exe = paddle.static.Executor(place)
            paddle_result = exe.run(
                feed={
                    "x": self.input_data,
                    "large_x": self.large_input_data,
                    "k": np.array([2]).astype("int32")
                },
                fetch_list=[
                    result1[0], result1[1], result2[0], result2[1], result3[0],
                    result3[1], result4[0], result4[1], result5[0], result5[1],
                    result6[0], result6[1], result7[0], result7[1]
                ])
            numpy_result = numpy_topk(self.input_data, k=2)
            self.assertTrue(np.allclose(paddle_result[0], numpy_result[0]))
            self.assertTrue(np.allclose(paddle_result[1], numpy_result[1]))
            numpy_result = numpy_topk(self.input_data, k=2, axis=-1)
            self.assertTrue(np.allclose(paddle_result[2], numpy_result[0]))
            self.assertTrue(np.allclose(paddle_result[3], numpy_result[1]))
            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
            self.assertTrue(np.allclose(paddle_result[4], numpy_result[0]))
            self.assertTrue(np.allclose(paddle_result[5], numpy_result[1]))
            numpy_result = numpy_topk(
                self.input_data, k=2, axis=1, largest=False)
            self.assertTrue(np.allclose(paddle_result[6], numpy_result[0]))
            self.assertTrue(np.allclose(paddle_result[7], numpy_result[1]))
            numpy_result = numpy_topk(
                self.input_data, k=2, axis=-1, largest=False)
            self.assertTrue(np.allclose(paddle_result[8], numpy_result[0]))
            self.assertTrue(np.allclose(paddle_result[9], numpy_result[1]))
            numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
            self.assertTrue(np.allclose(paddle_result[10], numpy_result[0]))
            self.assertTrue(np.allclose(paddle_result[11], numpy_result[1]))
            # sorted=False case: re-sort the fetched values before comparing
            sort_paddle = numpy_topk(paddle_result[12], axis=1, k=2)
            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
            self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
    def test_cases(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))
        for place in places:
            self.run_dygraph(place)
            self.run_static(place)


if __name__ == "__main__":
    unittest.main()