Polygon box transform op for OCR East detection. (#10802)
* Add quad transform. * Fix some syntax error. * Fix CUDA kernel launch configure. * Generalize geometry channels. * Rename QuadTransform to PolygonRestore. * Rename op. * Rename op and fix computation. * Modify CMakeLists.txt for box_restore op. * Refine code: 1. rename op 2. uncomment unitest on GPUshanyi15-patch-3
parent
a62bbd1ddc
commit
376c948e88
@ -0,0 +1,105 @@
|
||||
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class PolygonBoxTransformCPUKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
||||
"It must use CUDAPlace.");
|
||||
auto* in = ctx.Input<Tensor>("Input");
|
||||
auto in_dims = in->dims();
|
||||
const T* in_data = in->data<T>();
|
||||
auto* out = ctx.Output<Tensor>("Output");
|
||||
T* out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int batch_size = in_dims[0];
|
||||
int geo_channel = in_dims[1];
|
||||
int height = in_dims[2];
|
||||
int width = in_dims[3];
|
||||
int id = 0;
|
||||
for (int id_n = 0; id_n < batch_size * geo_channel; ++id_n) {
|
||||
for (int id_h = 0; id_h < height; ++id_h) {
|
||||
for (int id_w = 0; id_w < width; ++id_w) {
|
||||
id = id_n * height * width + width * id_h + id_w;
|
||||
if (id_n % 2 == 0) {
|
||||
out_data[id] = id_w - in_data[id];
|
||||
} else {
|
||||
out_data[id] = id_h - in_data[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class PolygonBoxTransformOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Input"),
|
||||
"Input (Input) of polygon_box transform op should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("Output"),
|
||||
"Output (Output) of polygon_box transform op should not be null.");
|
||||
|
||||
auto in_dim = ctx->GetInputDim("Input");
|
||||
|
||||
PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4.");
|
||||
PADDLE_ENFORCE_EQ(in_dim[1] % 2, 0,
|
||||
"input's second dimension must be even.");
|
||||
|
||||
ctx->SetOutputDim("Output", in_dim);
|
||||
}
|
||||
};
|
||||
|
||||
class PolygonBoxTransformOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput(
|
||||
"Input",
|
||||
"The input with shape [batch_size, geometry_channels, height, width]");
|
||||
AddOutput("Output", "The output with the same shape as input");
|
||||
|
||||
AddComment(R"DOC(
|
||||
PolygonBoxTransform Operator.
|
||||
The input is the final geometry output in detection network.
|
||||
We use 2*n numbers to denote the coordinate shift from n corner vertices of
|
||||
the polygon_box to the pixel location. As each distance offset contains two numbers (xi, yi),
|
||||
the geometry output contains 2*n channels.
|
||||
PolygonBoxTransform Operator is used to transform the coordinate shift to the real coordinate.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(polygon_box_transform, ops::PolygonBoxTransformOp,
|
||||
ops::PolygonBoxTransformOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
polygon_box_transform,
|
||||
ops::PolygonBoxTransformCPUKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::PolygonBoxTransformCPUKernel<paddle::platform::CPUPlace, double>);
|
@ -0,0 +1,76 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using platform::PADDLE_CUDA_NUM_THREADS;
|
||||
#define CUDA_BLOCK_SIZE 16
|
||||
|
||||
template <typename T>
|
||||
__global__ void PolygonBoxTransformKernel(const int n, const int h, const int w,
|
||||
const T* input, T* output) {
|
||||
int id_n = threadIdx.x + blockDim.x * blockIdx.x;
|
||||
int id_h = threadIdx.y + blockDim.y * blockIdx.y;
|
||||
int id_w = threadIdx.z + blockDim.z * blockIdx.z;
|
||||
if (id_n < n && id_h < h && id_w < w) {
|
||||
int id = id_n * h * w + w * id_h + id_w;
|
||||
if (id_n % 2 == 0) {
|
||||
output[id] = id_w - input[id];
|
||||
} else {
|
||||
output[id] = id_h - input[id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class PolygonBoxTransformOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"It must use CUDAPlace.");
|
||||
auto* in = ctx.Input<Tensor>("Input");
|
||||
auto in_dims = in->dims();
|
||||
const T* in_data = in->data<T>();
|
||||
auto* out = ctx.Output<Tensor>("Output");
|
||||
T* out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int batch_size = in_dims[0];
|
||||
int geo_channels = in_dims[1];
|
||||
int height = in_dims[2];
|
||||
int width = in_dims[3];
|
||||
dim3 threadsPerBlock(
|
||||
PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE),
|
||||
CUDA_BLOCK_SIZE, CUDA_BLOCK_SIZE);
|
||||
dim3 numBlocks((batch_size * geo_channels) / threadsPerBlock.x,
|
||||
(height + threadsPerBlock.y - 1) / threadsPerBlock.y,
|
||||
(width + threadsPerBlock.z - 1) / threadsPerBlock.z);
|
||||
auto stream = ctx.cuda_device_context().stream();
|
||||
PolygonBoxTransformKernel<T><<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
batch_size * geo_channels, height, width, in_data, out_data);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
polygon_box_transform,
|
||||
paddle::operators::PolygonBoxTransformOpCUDAKernel<float>,
|
||||
paddle::operators::PolygonBoxTransformOpCUDAKernel<double>);
|
@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def PolygonBoxRestore(input):
|
||||
shape = input.shape
|
||||
batch_size = shape[0]
|
||||
geo_channels = shape[1]
|
||||
h = shape[2]
|
||||
w = shape[3]
|
||||
h_indexes = np.array(range(h) * w).reshape(
|
||||
[w, h]).transpose()[np.newaxis, :] # [1, h, w]
|
||||
w_indexes = np.array(range(w) * h).reshape(
|
||||
[h, w])[np.newaxis, :] # [1, h, w]
|
||||
indexes = np.concatenate(
|
||||
(w_indexes, h_indexes))[np.newaxis, :] # [1, 2, h, w]
|
||||
indexes = indexes.repeat(
|
||||
[geo_channels / 2],
|
||||
axis=0)[np.newaxis, :] # [1, geo_channels/2, 2, h, w]
|
||||
indexes = indexes.repeat(
|
||||
[batch_size], axis=0) # [batch_size, geo_channels/2, 2, h, w]
|
||||
return indexes.reshape(
|
||||
input.shape) - input # [batch_size, geo_channels, h, w]
|
||||
|
||||
|
||||
class TestPolygonBoxRestoreOp(OpTest):
|
||||
def config(self):
|
||||
self.input_shape = (1, 8, 2, 2)
|
||||
|
||||
def setUp(self):
|
||||
self.config()
|
||||
self.op_type = "polygon_box_transform"
|
||||
input = np.random.random(self.input_shape).astype("float32")
|
||||
self.inputs = {'Input': input}
|
||||
output = PolygonBoxRestore(input)
|
||||
self.outputs = {'Output': output}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestCase1(TestPolygonBoxRestoreOp):
|
||||
def config(self):
|
||||
self.input_shape = (2, 10, 3, 2)
|
||||
|
||||
|
||||
class TestCase2(TestPolygonBoxRestoreOp):
|
||||
def config(self):
|
||||
self.input_shape = (3, 12, 4, 5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue