update iou_sim code

fix-profile-doc-typo
wanghaox 7 years ago
parent 3b63815629
commit 528bcac52c

@ -23,12 +23,16 @@ class IOUSimilarityOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of IOUSimilarityOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of IOUSimilarityOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The shape of X is [N, 4]");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The rank of Input(X) must be 2.");
PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]");
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The shape of Y is [M, 4]");
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The rank of Input(Y) must be 2.");
PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]");
ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]}));
@ -39,16 +43,18 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor, default Tensor<float>) "
"BoxList X holding N boxes, each box is "
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4].");
AddInput(
"Y",
"(Tensor, default Tensor<float>) "
"BoxList Y holding M boxes, each box is "
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4].");
AddInput("X",
"(Tensor, default Tensor<float>) "
"Box list X holds N boxes, each box is "
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, "
"4]. [xmin, ymin] is the lower left coordinate of the box, and "
"[xmax, ymax] is the right upper coordinate of the box.");
AddInput("Y",
"(Tensor, default Tensor<float>) "
"Box list Y holds M boxes, each box is "
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, "
"4]. [xmin, ymin] is the lower left coordinate of the box, and "
"[xmax, ymax] is the right upper coordinate of the box.");
AddOutput(
"Out",
@ -57,7 +63,7 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
IOU Similarity Operator.
Computes pairwise intersection-over-union between box collections.
Computes intersection-over-union (IOU) between two box lists.
)DOC");
}
};

@ -0,0 +1,21 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/iou_similarity_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
iou_similarity,
ops::IOUSimilarityKernel<paddle::platform::CUDADeviceContext, float>);

@ -17,16 +17,19 @@ limitations under the License. */
#include "paddle/platform/for_range.h"
template <typename T>
inline T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, T ymin2,
T xmax2, T ymax2) {
inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2,
T ymin2, T xmax2, T ymax2) {
constexpr T zero = static_cast<T>(0);
T area1 = (ymax1 - ymin1) * (xmax1 - xmin1);
T area2 = (ymax2 - ymin2) * (xmax2 - xmin2);
T inter_xmax = std::min(xmax1, xmax2);
T inter_ymax = std::min(ymax1, ymax2);
T inter_xmin = std::max(xmin1, xmin2);
T inter_ymin = std::max(ymin1, ymin2);
T inter_height = std::max(inter_ymax - inter_ymin, static_cast<T>(0));
T inter_width = std::max(inter_xmax - inter_xmin, static_cast<T>(0));
T inter_xmax = xmax1 > xmax2 ? xmax2 : xmax1;
T inter_ymax = ymax1 > ymax2 ? ymax2 : ymax1;
T inter_xmin = xmin1 > xmin2 ? xmin1 : xmin2;
T inter_ymin = ymin1 > ymin2 ? ymin1 : ymin2;
T inter_height = inter_ymax - inter_ymin;
T inter_width = inter_xmax - inter_xmin;
inter_height = inter_height > zero ? inter_height : zero;
inter_width = inter_width > zero ? inter_width : zero;
T inter_area = inter_width * inter_height;
T union_area = area1 + area2 - inter_area;
T sim_score = inter_area / union_area;

@ -1,3 +1,16 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest
import numpy as np
import sys
@ -6,23 +19,11 @@ from op_test import OpTest
class TestIOUSimilarityOp(OpTest):
def set_data(self):
self.init_test_data()
self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
self.outputs = {'Out': self.output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
return
def setUp(self):
self.op_type = "iou_similarity"
self.set_data()
def init_test_data(self):
self.boxes1 = np.array(
[[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]).astype('float32')
self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
@ -30,6 +31,12 @@ class TestIOUSimilarityOp(OpTest):
self.output = np.array(
[[2.0 / 16.0, 0, 6.0 / 400.0],
[1.0 / 16.0, 0.0, 5.0 / 400.0]]).astype('float32')
# self.output = np.array([[0, 0, 0],
# [0, 0, 0]]).astype('float32')
self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
self.outputs = {'Out': self.output}
if __name__ == '__main__':

Loading…
Cancel
Save