parent
cb6b468e35
commit
2ad5a6f0d1
@ -0,0 +1,74 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/iou_similarity_op.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class IOUSimilarityOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||||
|
auto x_dims = ctx->GetInputDim("X");
|
||||||
|
auto y_dims = ctx->GetInputDim("Y");
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The shape of X is [N, 4]");
|
||||||
|
PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]");
|
||||||
|
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The shape of Y is [M, 4]");
|
||||||
|
PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]");
|
||||||
|
|
||||||
|
ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||||
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput(
|
||||||
|
"X",
|
||||||
|
"(Tensor, default Tensor<float>) "
|
||||||
|
"BoxList X holding N boxes, each box is "
|
||||||
|
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4].");
|
||||||
|
AddInput(
|
||||||
|
"Y",
|
||||||
|
"(Tensor, default Tensor<float>) "
|
||||||
|
"BoxList Y holding M boxes, each box is "
|
||||||
|
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4].");
|
||||||
|
|
||||||
|
AddOutput(
|
||||||
|
"Out",
|
||||||
|
"(Tensor) The output of iou_similarity op, a tensor with shape [N, M] "
|
||||||
|
"representing pairwise iou scores.");
|
||||||
|
|
||||||
|
AddComment(R"DOC(
|
||||||
|
IOU Similarity Operator.
|
||||||
|
Computes pairwise intersection-over-union between box collections.
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp,
|
||||||
|
ops::IOUSimilarityOpMaker);
|
||||||
|
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
iou_similarity,
|
||||||
|
ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::IOUSimilarityKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,87 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
#include "paddle/platform/for_range.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, T ymin2,
|
||||||
|
T xmax2, T ymax2) {
|
||||||
|
T area1 = (ymax1 - ymin1) * (xmax1 - xmin1);
|
||||||
|
T area2 = (ymax2 - ymin2) * (xmax2 - xmin2);
|
||||||
|
T inter_xmax = std::min(xmax1, xmax2);
|
||||||
|
T inter_ymax = std::min(ymax1, ymax2);
|
||||||
|
T inter_xmin = std::max(xmin1, xmin2);
|
||||||
|
T inter_ymin = std::max(ymin1, ymin2);
|
||||||
|
T inter_height = std::max(inter_ymax - inter_ymin, static_cast<T>(0));
|
||||||
|
T inter_width = std::max(inter_xmax - inter_xmin, static_cast<T>(0));
|
||||||
|
T inter_area = inter_width * inter_height;
|
||||||
|
T union_area = area1 + area2 - inter_area;
|
||||||
|
T sim_score = inter_area / union_area;
|
||||||
|
return sim_score;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct IOUSimilarityFunctor {
|
||||||
|
IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols)
|
||||||
|
: x_(x), y_(y), z_(z), cols_(static_cast<size_t>(cols)) {}
|
||||||
|
|
||||||
|
inline HOSTDEVICE void operator()(size_t row_id) const {
|
||||||
|
T x_min1 = x_[row_id * 4];
|
||||||
|
T y_min1 = x_[row_id * 4 + 1];
|
||||||
|
T x_max1 = x_[row_id * 4 + 2];
|
||||||
|
T y_max1 = x_[row_id * 4 + 3];
|
||||||
|
for (int i = 0; i < cols_; ++i) {
|
||||||
|
T x_min2 = y_[i * 4];
|
||||||
|
T y_min2 = y_[i * 4 + 1];
|
||||||
|
T x_max2 = y_[i * 4 + 2];
|
||||||
|
T y_max2 = y_[i * 4 + 3];
|
||||||
|
|
||||||
|
T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2,
|
||||||
|
x_max2, y_max2);
|
||||||
|
|
||||||
|
z_[row_id * cols_ + i] = sim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const T* x_;
|
||||||
|
const T* y_;
|
||||||
|
T* z_;
|
||||||
|
const size_t cols_;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class IOUSimilarityKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
const framework::Tensor* in_x = ctx.Input<framework::Tensor>("X");
|
||||||
|
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
|
||||||
|
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
|
||||||
|
|
||||||
|
int x_n = in_x->dims()[0];
|
||||||
|
int y_n = in_y->dims()[0];
|
||||||
|
IOUSimilarityFunctor<T> functor(in_x->data<T>(), in_y->data<T>(),
|
||||||
|
out->mutable_data<T>(ctx.GetPlace()), y_n);
|
||||||
|
|
||||||
|
platform::ForRange<DeviceContext> for_range(
|
||||||
|
static_cast<const DeviceContext&>(ctx.device_context()), x_n);
|
||||||
|
for_range(functor);
|
||||||
|
}
|
||||||
|
}; // namespace operators
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,36 @@
|
|||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestIOUSimilarityOp(OpTest):
|
||||||
|
def set_data(self):
|
||||||
|
self.init_test_data()
|
||||||
|
self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
|
||||||
|
|
||||||
|
self.outputs = {'Out': self.output}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
def test_check_grad(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.op_type = "iou_similarity"
|
||||||
|
self.set_data()
|
||||||
|
|
||||||
|
def init_test_data(self):
|
||||||
|
self.boxes1 = np.array(
|
||||||
|
[[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]).astype('float32')
|
||||||
|
self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
|
||||||
|
[0.0, 0.0, 20.0, 20.0]]).astype('float32')
|
||||||
|
self.output = np.array(
|
||||||
|
[[2.0 / 16.0, 0, 6.0 / 400.0],
|
||||||
|
[1.0 / 16.0, 0.0, 5.0 / 400.0]]).astype('float32')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue