Merge pull request #15356 from jerrywgz/add_clip_op
Add box clip oprevert-15296-async_double_buffered_py_reader
commit
1743d1a58f
@ -0,0 +1,86 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/box_clip_op.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class BoxClipOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
||||
"Input(Input) of BoxClipOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
|
||||
"Input(ImInfo) of BoxClipOp should not be null.");
|
||||
|
||||
auto input_box_dims = ctx->GetInputDim("Input");
|
||||
auto im_info_dims = ctx->GetInputDim("ImInfo");
|
||||
|
||||
if (ctx->IsRuntime()) {
|
||||
auto input_box_size = input_box_dims.size();
|
||||
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
|
||||
"The last dimension of Input must be 4");
|
||||
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
|
||||
"The rank of Input(Input) in BoxClipOp must be 2");
|
||||
PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
|
||||
"The last dimension of ImInfo must be 3");
|
||||
}
|
||||
ctx->ShareDim("Input", /*->*/ "Output");
|
||||
ctx->ShareLoD("Input", /*->*/ "Output");
|
||||
}
|
||||
};
|
||||
|
||||
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Input",
|
||||
"(LoDTensor) "
|
||||
"Input is a LoDTensor with shape [..., 4] holds 4 points"
|
||||
"in last dimension in format [xmin, ymin, xmax, ymax]");
|
||||
AddInput("ImInfo",
|
||||
"(Tensor) Information for image reshape is in shape (N, 3), "
|
||||
"in format (height, width, im_scale)");
|
||||
AddOutput("Output",
|
||||
"(LoDTensor) "
|
||||
"Output is a LoDTensor with the same shape as Input"
|
||||
"and it is the result after clip");
|
||||
AddComment(R"DOC(
|
||||
This operator clips input boxes to original input images.
|
||||
|
||||
For each input box, The formula is given as follows:
|
||||
|
||||
$$xmin = \max(\min(xmin, im_w - 1), 0)$$
|
||||
$$ymin = \max(\min(ymin, im_h - 1), 0)$$
|
||||
$$xmax = \max(\min(xmax, im_w - 1), 0)$$
|
||||
$$ymax = \max(\min(ymax, im_h - 1), 0)$$
|
||||
|
||||
where im_w and im_h are computed from ImInfo, the formula is given as follows:
|
||||
|
||||
$$im_w = \round(width / im_scale)$$
|
||||
$$im_h = \round(height / im_scale)$$
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
box_clip, ops::BoxClipKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::BoxClipKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,74 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detection/box_clip_op.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/hostdevice.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTenso = framework::LoDTensor;
|
||||
|
||||
static constexpr int ImInfoSize = 3;
|
||||
|
||||
template <typename T, int BlockSize>
|
||||
static __global__ void GPUBoxClip(const T *input, const size_t *lod,
|
||||
const size_t width, const T *im_info,
|
||||
T *output) {
|
||||
T im_w = round(im_info[blockIdx.x * ImInfoSize + 1] /
|
||||
im_info[blockIdx.x * ImInfoSize + 2]);
|
||||
T im_h = round(im_info[blockIdx.x * ImInfoSize] /
|
||||
im_info[blockIdx.x * ImInfoSize + 2]);
|
||||
for (int i = threadIdx.x; i < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * width;
|
||||
i += BlockSize) {
|
||||
int idx = lod[blockIdx.x] * width + i;
|
||||
T im_size = (idx % 2 == 0) ? im_w : im_h;
|
||||
output[idx] = max(min(input[idx], im_size - 1), T(0.));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class GPUBoxClipKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
auto *input = context.Input<LoDTensor>("Input");
|
||||
auto *im_info = context.Input<Tensor>("ImInfo");
|
||||
auto *output = context.Output<LoDTensor>("Output");
|
||||
const int64_t num = input->dims()[0];
|
||||
const int64_t bbox_width = input->numel() / num;
|
||||
auto lod = input->lod();
|
||||
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
|
||||
auto &dev_ctx = context.template device_context<DeviceContext>();
|
||||
auto stream = dev_ctx.stream();
|
||||
const size_t batch_size = lod.back().size() - 1;
|
||||
T *output_data = output->mutable_data<T>(dev_ctx.GetPlace());
|
||||
GPUBoxClip<T, 512><<<batch_size, 512, 0, stream>>>(
|
||||
input->data<T>(), abs_offset_lod[0].CUDAMutableData(dev_ctx.GetPlace()),
|
||||
bbox_width, im_info->data<T>(), output_data);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
box_clip, ops::GPUBoxClipKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::GPUBoxClipKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detection/bbox_util.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class BoxClipKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* input_box = context.Input<LoDTensor>("Input");
|
||||
auto* im_info = context.Input<LoDTensor>("ImInfo");
|
||||
auto* output_box = context.Output<LoDTensor>("Output");
|
||||
auto& dev_ctx =
|
||||
context.template device_context<platform::CPUDeviceContext>();
|
||||
output_box->mutable_data<T>(context.GetPlace());
|
||||
if (input_box->lod().size()) {
|
||||
PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL,
|
||||
"Only support 1 level of LoD.");
|
||||
}
|
||||
auto box_lod = input_box->lod().back();
|
||||
int64_t n = static_cast<int64_t>(box_lod.size() - 1);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
Tensor im_info_slice = im_info->Slice(i, i + 1);
|
||||
Tensor box_slice = input_box->Slice(box_lod[i], box_lod[i + 1]);
|
||||
Tensor output_slice = output_box->Slice(box_lod[i], box_lod[i + 1]);
|
||||
ClipTiledBoxes<T>(dev_ctx, im_info_slice, box_slice, &output_slice);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
import copy
|
||||
|
||||
|
||||
def box_clip(input_box, im_info, output_box):
|
||||
im_w = round(im_info[1] / im_info[2])
|
||||
im_h = round(im_info[0] / im_info[2])
|
||||
output_box[:, :, 0] = np.maximum(
|
||||
np.minimum(input_box[:, :, 0], im_w - 1), 0)
|
||||
output_box[:, :, 1] = np.maximum(
|
||||
np.minimum(input_box[:, :, 1], im_h - 1), 0)
|
||||
output_box[:, :, 2] = np.maximum(
|
||||
np.minimum(input_box[:, :, 2], im_w - 1), 0)
|
||||
output_box[:, :, 3] = np.maximum(
|
||||
np.minimum(input_box[:, :, 3], im_h - 1), 0)
|
||||
|
||||
|
||||
def batch_box_clip(input_boxes, im_info, lod):
|
||||
n = input_boxes.shape[0]
|
||||
m = input_boxes.shape[1]
|
||||
output_boxes = np.zeros((n, m, 4), dtype=np.float32)
|
||||
cur_offset = 0
|
||||
for i in range(len(lod)):
|
||||
box_clip(input_boxes[cur_offset:(cur_offset + lod[i]), :, :],
|
||||
im_info[i, :],
|
||||
output_boxes[cur_offset:(cur_offset + lod[i]), :, :])
|
||||
cur_offset += lod[i]
|
||||
return output_boxes
|
||||
|
||||
|
||||
class TestBoxClipOp(OpTest):
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "box_clip"
|
||||
lod = [[1, 2, 3]]
|
||||
input_boxes = np.random.random((6, 10, 4)) * 5
|
||||
im_info = np.array([[5, 8, 1.], [6, 6, 1.], [7, 5, 1.]])
|
||||
output_boxes = batch_box_clip(input_boxes, im_info, lod[0])
|
||||
|
||||
self.inputs = {
|
||||
'Input': (input_boxes.astype('float32'), lod),
|
||||
'ImInfo': im_info.astype('float32'),
|
||||
}
|
||||
self.outputs = {'Output': output_boxes}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue