commit
d43932c846
@ -0,0 +1,96 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/iou_similarity_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class IOUSimilarityOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of IOUSimilarityOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Y"),
|
||||
"Input(Y) of IOUSimilarityOp should not be null.");
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto y_dims = ctx->GetInputDim("Y");
|
||||
|
||||
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The rank of Input(X) must be 2.");
|
||||
PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]");
|
||||
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The rank of Input(Y) must be 2.");
|
||||
PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]");
|
||||
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]}));
|
||||
}
|
||||
};
|
||||
|
||||
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X",
|
||||
"(LoDTensor, default LoDTensor<float>) "
|
||||
"Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
|
||||
"each box is represented as [xmin, ymin, xmax, ymax], "
|
||||
"the shape of X is [N, 4]. [xmin, ymin] is the left top "
|
||||
"coordinate of the box if the input is image feature map, they "
|
||||
"are close to the origin of the coordinate system. "
|
||||
"[xmax, ymax] is the right bottom coordinate of the box. "
|
||||
"This tensor can contain LoD information to represent a batch "
|
||||
"of inputs. One instance of this batch can contain different "
|
||||
"numbers of entities.");
|
||||
AddInput("Y",
|
||||
"(Tensor, default Tensor<float>) "
|
||||
"Box list Y holds M boxes, each box is represented as "
|
||||
"[xmin, ymin, xmax, ymax], the shape of X is [N, 4]. "
|
||||
"[xmin, ymin] is the left top coordinate of the box if the "
|
||||
"input is image feature map, and [xmax, ymax] is the right "
|
||||
"bottom coordinate of the box.");
|
||||
|
||||
AddOutput("Out",
|
||||
"(LoDTensor, the lod is same as input X) The output of "
|
||||
"iou_similarity op, a tensor with shape [N, M] "
|
||||
"representing pairwise iou scores.");
|
||||
|
||||
AddComment(R"DOC(
|
||||
IOU Similarity Operator.
|
||||
Computes intersection-over-union (IOU) between two box lists.
|
||||
Box list 'X' should be a LoDTensor and 'Y' is a common Tensor,
|
||||
boxes in 'Y' are shared by all instance of the batched inputs of X.
|
||||
Given two boxes A and B, the calculation of IOU is as follows:
|
||||
|
||||
$$
|
||||
IOU(A, B) =
|
||||
\frac{area(A\cap B)}{area(A)+area(B)-area(A\cap B)}
|
||||
$$
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp,
|
||||
ops::IOUSimilarityOpMaker);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
iou_similarity,
|
||||
ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,21 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/iou_similarity_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
iou_similarity,
|
||||
ops::IOUSimilarityKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::IOUSimilarityKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,90 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/platform/for_range.h"
|
||||
|
||||
template <typename T>
|
||||
inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2,
|
||||
T ymin2, T xmax2, T ymax2) {
|
||||
constexpr T zero = static_cast<T>(0);
|
||||
T area1 = (ymax1 - ymin1) * (xmax1 - xmin1);
|
||||
T area2 = (ymax2 - ymin2) * (xmax2 - xmin2);
|
||||
T inter_xmax = xmax1 > xmax2 ? xmax2 : xmax1;
|
||||
T inter_ymax = ymax1 > ymax2 ? ymax2 : ymax1;
|
||||
T inter_xmin = xmin1 > xmin2 ? xmin1 : xmin2;
|
||||
T inter_ymin = ymin1 > ymin2 ? ymin1 : ymin2;
|
||||
T inter_height = inter_ymax - inter_ymin;
|
||||
T inter_width = inter_xmax - inter_xmin;
|
||||
inter_height = inter_height > zero ? inter_height : zero;
|
||||
inter_width = inter_width > zero ? inter_width : zero;
|
||||
T inter_area = inter_width * inter_height;
|
||||
T union_area = area1 + area2 - inter_area;
|
||||
T sim_score = inter_area / union_area;
|
||||
return sim_score;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct IOUSimilarityFunctor {
|
||||
IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols)
|
||||
: x_(x), y_(y), z_(z), cols_(static_cast<size_t>(cols)) {}
|
||||
|
||||
inline HOSTDEVICE void operator()(size_t row_id) const {
|
||||
T x_min1 = x_[row_id * 4];
|
||||
T y_min1 = x_[row_id * 4 + 1];
|
||||
T x_max1 = x_[row_id * 4 + 2];
|
||||
T y_max1 = x_[row_id * 4 + 3];
|
||||
for (size_t i = 0; i < cols_; ++i) {
|
||||
T x_min2 = y_[i * 4];
|
||||
T y_min2 = y_[i * 4 + 1];
|
||||
T x_max2 = y_[i * 4 + 2];
|
||||
T y_max2 = y_[i * 4 + 3];
|
||||
|
||||
T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2,
|
||||
x_max2, y_max2);
|
||||
|
||||
z_[row_id * cols_ + i] = sim;
|
||||
}
|
||||
}
|
||||
const T* x_;
|
||||
const T* y_;
|
||||
T* z_;
|
||||
const size_t cols_;
|
||||
};
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class IOUSimilarityKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
const framework::LoDTensor* in_x = ctx.Input<framework::LoDTensor>("X");
|
||||
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
|
||||
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
|
||||
|
||||
int x_n = in_x->dims()[0];
|
||||
int y_n = in_y->dims()[0];
|
||||
IOUSimilarityFunctor<T> functor(in_x->data<T>(), in_y->data<T>(),
|
||||
out->mutable_data<T>(ctx.GetPlace()), y_n);
|
||||
|
||||
platform::ForRange<DeviceContext> for_range(
|
||||
static_cast<const DeviceContext&>(ctx.device_context()), x_n);
|
||||
for_range(functor);
|
||||
}
|
||||
}; // namespace operators
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,55 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestIOUSimilarityOp(OpTest):
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "iou_similarity"
|
||||
self.boxes1 = np.array(
|
||||
[[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]).astype('float32')
|
||||
self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
|
||||
[0.0, 0.0, 20.0, 20.0]]).astype('float32')
|
||||
self.output = np.array(
|
||||
[[2.0 / 16.0, 0, 6.0 / 400.0],
|
||||
[1.0 / 16.0, 0.0, 5.0 / 400.0]]).astype('float32')
|
||||
|
||||
self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
|
||||
|
||||
self.outputs = {'Out': self.output}
|
||||
|
||||
|
||||
class TestIOUSimilarityOpWithLoD(TestIOUSimilarityOp):
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
super(TestIOUSimilarityOpWithLoD, self).setUp()
|
||||
self.boxes1_lod = [[0, 1, 2]]
|
||||
self.output_lod = [[0, 1, 2]]
|
||||
|
||||
self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
|
||||
self.outputs = {'Out': (self.output, self.output_lod)}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue