Merge pull request #15356 from jerrywgz/add_clip_op
	
		
	
				
					
				
			Add box clip op
			revert-15296-async_double_buffered_py_reader
						commit
						1743d1a58f
					
				@ -0,0 +1,86 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/operators/detection/box_clip_op.h"
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
class BoxClipOp : public framework::OperatorWithKernel {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::OperatorWithKernel::OperatorWithKernel;
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  void InferShape(framework::InferShapeContext* ctx) const override {
 | 
				
			||||
    PADDLE_ENFORCE(ctx->HasInput("Input"),
 | 
				
			||||
                   "Input(Input) of BoxClipOp should not be null.");
 | 
				
			||||
    PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
 | 
				
			||||
                   "Input(ImInfo) of BoxClipOp should not be null.");
 | 
				
			||||
 | 
				
			||||
    auto input_box_dims = ctx->GetInputDim("Input");
 | 
				
			||||
    auto im_info_dims = ctx->GetInputDim("ImInfo");
 | 
				
			||||
 | 
				
			||||
    if (ctx->IsRuntime()) {
 | 
				
			||||
      auto input_box_size = input_box_dims.size();
 | 
				
			||||
      PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
 | 
				
			||||
                        "The last dimension of Input must be 4");
 | 
				
			||||
      PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
 | 
				
			||||
                        "The rank of Input(Input) in BoxClipOp must be 2");
 | 
				
			||||
      PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
 | 
				
			||||
                        "The last dimension of ImInfo must be 3");
 | 
				
			||||
    }
 | 
				
			||||
    ctx->ShareDim("Input", /*->*/ "Output");
 | 
				
			||||
    ctx->ShareLoD("Input", /*->*/ "Output");
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
// Declares the inputs, output and user-facing documentation of box_clip.
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must end with a space.
    AddInput("Input",
             "(LoDTensor) "
             "Input is a LoDTensor with shape [..., 4] holds 4 points "
             "in last dimension in format [xmin, ymin, xmax, ymax]");
    AddInput("ImInfo",
             "(Tensor) Information for image reshape is in shape (N, 3), "
             "in format (height, width, im_scale)");
    AddOutput("Output",
              "(LoDTensor) "
              "Output is a LoDTensor with the same shape as Input "
              "and it is the result after clip");
    AddComment(R"DOC(
This operator clips input boxes to original input images.

For each input box, The formula is given as follows:

       $$xmin = \max(\min(xmin, im_w - 1), 0)$$
       $$ymin = \max(\min(ymin, im_h - 1), 0)$$
       $$xmax = \max(\min(xmax, im_w - 1), 0)$$
       $$ymax = \max(\min(ymax, im_h - 1), 0)$$

where im_w and im_h are computed from ImInfo, the formula is given as follows:

       $$im_w = \round(width / im_scale)$$
       $$im_h = \round(height / im_scale)$$
)DOC");
  }
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
 | 
				
			||||
namespace ops = paddle::operators;
 | 
				
			||||
REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker,
 | 
				
			||||
                  paddle::framework::EmptyGradOpMaker);
 | 
				
			||||
REGISTER_OP_CPU_KERNEL(
 | 
				
			||||
    box_clip, ops::BoxClipKernel<paddle::platform::CPUDeviceContext, float>,
 | 
				
			||||
    ops::BoxClipKernel<paddle::platform::CPUDeviceContext, double>);
 | 
				
			||||
@ -0,0 +1,74 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include <algorithm>
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
#include "paddle/fluid/operators/detection/box_clip_op.h"
 | 
				
			||||
#include "paddle/fluid/operators/math/math_function.h"
 | 
				
			||||
#include "paddle/fluid/platform/cuda_primitives.h"
 | 
				
			||||
#include "paddle/fluid/platform/hostdevice.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
using Tensor = framework::Tensor;
// Same alias that box_clip_op.h (included above) declares; re-declaring an
// alias for the identical type in the same namespace is legal.
using LoDTensor = framework::LoDTensor;

// Number of values stored per ImInfo row; the kernel uses it as the row
// stride when indexing im_info.
static constexpr int ImInfoSize = 3;
 | 
				
			||||
 | 
				
			||||
// CUDA kernel that clips box coordinates to the bounds of their image.
//
// blockIdx.x selects both the ImInfo row and the LoD segment
// [lod[b], lod[b + 1]) handled by the block; threads stride over the
// segment's elements in steps of BlockSize.
//
//   input   - box coordinates, `width` values per LoD row
//   lod     - absolute row offsets, one more entry than there are blocks
//   width   - number of values per LoD row
//   im_info - ImInfoSize values per image: index 0 and 1 are divided by
//             index 2 to obtain the image height and width respectively
//   output  - receives the clipped coordinates, same layout as input
template <typename T, int BlockSize>
static __global__ void GPUBoxClip(const T *input, const size_t *lod,
                                  const size_t width, const T *im_info,
                                  T *output) {
  const T *info = im_info + blockIdx.x * ImInfoSize;
  const T im_w = round(info[1] / info[2]);
  const T im_h = round(info[0] / info[2]);

  // Hoist the loop-invariant segment bounds and keep all indices in size_t:
  // the offsets are size_t, so an int index would mix signed/unsigned
  // comparisons and truncate for very large inputs.
  const size_t begin = lod[blockIdx.x] * width;
  const size_t count = (lod[blockIdx.x + 1] - lod[blockIdx.x]) * width;
  for (size_t i = threadIdx.x; i < count; i += BlockSize) {
    const size_t idx = begin + i;
    // Even flat indices are clipped against the width, odd ones against the
    // height. NOTE(review): this relies on `width` being even (boxes are
    // [xmin, ymin, xmax, ymax]) -- confirm against callers.
    const T im_size = (idx % 2 == 0) ? im_w : im_h;
    output[idx] = max(min(input[idx], im_size - 1), T(0.));
  }
}
 | 
				
			||||
 | 
				
			||||
// GPU implementation of box_clip: launches one CUDA block per LoD sequence
// (i.e. per ImInfo row) to clip the boxes of that sequence.
template <typename DeviceContext, typename T>
class GPUBoxClipKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto *input = context.Input<LoDTensor>("Input");
    auto *im_info = context.Input<Tensor>("ImInfo");
    auto *output = context.Output<LoDTensor>("Output");

    // The kernel indexes level 0 of the absolute-offset LoD with a block
    // count taken from the last level, which is only consistent for a
    // single-level LoD. Checking it up front also keeps lod.back() below
    // from being called on an empty LoD.
    const auto &lod = input->lod();
    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support 1 level of LoD.");
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
    const size_t batch_size = lod.back().empty() ? 0 : lod.back().size() - 1;

    auto &dev_ctx = context.template device_context<DeviceContext>();
    auto stream = dev_ctx.stream();
    T *output_data = output->mutable_data<T>(dev_ctx.GetPlace());

    // Nothing to clip: avoid dividing by zero below and launching a grid
    // with zero blocks.
    const int64_t num = input->dims()[0];
    if (num == 0 || batch_size == 0) {
      return;
    }

    // Every block reads its own ImInfo row, so there must be one per image.
    PADDLE_ENFORCE_GE(im_info->dims()[0], static_cast<int64_t>(batch_size),
                      "ImInfo must provide one row for each image in Input.");

    const int64_t bbox_width = input->numel() / num;
    GPUBoxClip<T, 512><<<batch_size, 512, 0, stream>>>(
        input->data<T>(), abs_offset_lod[0].CUDAMutableData(dev_ctx.GetPlace()),
        bbox_width, im_info->data<T>(), output_data);
  }
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
 | 
				
			||||
// Registers the CUDA kernels of box_clip for float and double.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(box_clip,
                        ops::GPUBoxClipKernel<plat::CUDADeviceContext, float>,
                        ops::GPUBoxClipKernel<plat::CUDADeviceContext, double>);
 | 
				
			||||
@ -0,0 +1,50 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
#include <string>
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
#include "paddle/fluid/operators/detection/bbox_util.h"
 | 
				
			||||
#include "paddle/fluid/operators/math/math_function.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
using Tensor = framework::Tensor;
 | 
				
			||||
using LoDTensor = framework::LoDTensor;
 | 
				
			||||
 | 
				
			||||
template <typename DeviceContext, typename T>
 | 
				
			||||
class BoxClipKernel : public framework::OpKernel<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  void Compute(const framework::ExecutionContext& context) const override {
 | 
				
			||||
    auto* input_box = context.Input<LoDTensor>("Input");
 | 
				
			||||
    auto* im_info = context.Input<LoDTensor>("ImInfo");
 | 
				
			||||
    auto* output_box = context.Output<LoDTensor>("Output");
 | 
				
			||||
    auto& dev_ctx =
 | 
				
			||||
        context.template device_context<platform::CPUDeviceContext>();
 | 
				
			||||
    output_box->mutable_data<T>(context.GetPlace());
 | 
				
			||||
    if (input_box->lod().size()) {
 | 
				
			||||
      PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL,
 | 
				
			||||
                        "Only support 1 level of LoD.");
 | 
				
			||||
    }
 | 
				
			||||
    auto box_lod = input_box->lod().back();
 | 
				
			||||
    int64_t n = static_cast<int64_t>(box_lod.size() - 1);
 | 
				
			||||
    for (int i = 0; i < n; ++i) {
 | 
				
			||||
      Tensor im_info_slice = im_info->Slice(i, i + 1);
 | 
				
			||||
      Tensor box_slice = input_box->Slice(box_lod[i], box_lod[i + 1]);
 | 
				
			||||
      Tensor output_slice = output_box->Slice(box_lod[i], box_lod[i + 1]);
 | 
				
			||||
      ClipTiledBoxes<T>(dev_ctx, im_info_slice, box_slice, &output_slice);
 | 
				
			||||
    }
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,70 @@
 | 
				
			||||
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
from __future__ import print_function
 | 
				
			||||
 | 
				
			||||
import unittest
 | 
				
			||||
import numpy as np
 | 
				
			||||
import sys
 | 
				
			||||
import math
 | 
				
			||||
from op_test import OpTest
 | 
				
			||||
import copy
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
def _round_half_away_from_zero(value):
    """Round to the nearest integer with ties going away from zero.

    This is the behaviour of C/C++ ``round()``. Python 3's built-in
    ``round()`` sends ties to the nearest even integer instead, so it must
    not be used to mirror a C/C++ computation.
    """
    return math.copysign(math.floor(abs(value) + 0.5), value)


def box_clip(input_box, im_info, output_box):
    """Clip the boxes of one image in place into ``output_box``.

    Args:
        input_box: array of shape [..., 4] holding [xmin, ymin, xmax, ymax]
            in the last dimension.
        im_info: sequence (height, width, im_scale) for the image.
        output_box: array with the same shape as ``input_box``; receives the
            clipped coordinates (written in place, nothing is returned).

    x coordinates are clipped to [0, im_w - 1] and y coordinates to
    [0, im_h - 1], where im_w = round(width / im_scale) and
    im_h = round(height / im_scale).
    """
    # Use C-style rounding so that .5 ties agree with the operator kernels.
    im_w = _round_half_away_from_zero(im_info[1] / im_info[2])
    im_h = _round_half_away_from_zero(im_info[0] / im_info[2])
    # Ellipsis indexing accepts any number of leading dimensions.
    output_box[..., 0] = np.maximum(np.minimum(input_box[..., 0], im_w - 1), 0)
    output_box[..., 1] = np.maximum(np.minimum(input_box[..., 1], im_h - 1), 0)
    output_box[..., 2] = np.maximum(np.minimum(input_box[..., 2], im_w - 1), 0)
    output_box[..., 3] = np.maximum(np.minimum(input_box[..., 3], im_h - 1), 0)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
def batch_box_clip(input_boxes, im_info, lod):
    """Clip a batch of boxes image by image.

    Args:
        input_boxes: array whose first two dimensions are (n, m); rows are
            grouped into consecutive per-image segments.
        im_info: 2-D array; row i is passed to ``box_clip`` for segment i.
        lod: sequence of segment lengths (number of rows per image).

    Returns:
        A float32 array of shape (n, m, 4) with the clipped boxes.
    """
    row_count, boxes_per_row = input_boxes.shape[0], input_boxes.shape[1]
    clipped = np.zeros((row_count, boxes_per_row, 4), dtype=np.float32)
    start = 0
    for image_idx, segment_len in enumerate(lod):
        stop = start + segment_len
        # box_clip writes into the slice, which is a view of `clipped`.
        box_clip(input_boxes[start:stop, :, :], im_info[image_idx, :],
                 clipped[start:stop, :, :])
        start = stop
    return clipped
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestBoxClipOp(OpTest):
    """Checks the box_clip operator against the NumPy reference above."""

    def setUp(self):
        self.op_type = "box_clip"
        # Three images owning 1, 2 and 3 rows of the input respectively.
        seq_lens = [[1, 2, 3]]
        boxes = np.random.random((6, 10, 4)) * 5
        # One (height, width, im_scale) row per image.
        image_info = np.array([[5, 8, 1.], [6, 6, 1.], [7, 5, 1.]])
        expected = batch_box_clip(boxes, image_info, seq_lens[0])

        self.inputs = {
            'Input': (boxes.astype('float32'), seq_lens),
            'ImInfo': image_info.astype('float32'),
        }
        self.outputs = {'Output': expected}

    def test_check_output(self):
        self.check_output()
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
if __name__ == "__main__":
    # Run the test cases when the file is executed as a script.
    unittest.main()
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue