yolo_box OP add Attr(clip_bbox). (#21620)

* yolo_box OP add Attr(clip_bbox). test=develop
paddle_tiny_install
Kaipeng Deng 5 years ago committed by GitHub
parent c4f8f3bddc
commit 943a44492b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,9 +43,12 @@ class YoloBoxOp : public framework::OperatorWithKernel {
"+ class_num))."); "+ class_num)).");
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2, PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
"Input(ImgSize) should be a 2-D tensor."); "Input(ImgSize) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ( if ((dim_imgsize[0] > 0 && dim_x[0] > 0) || ctx->IsRuntime()) {
dim_imgsize[0], dim_x[0], PADDLE_ENFORCE_EQ(
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."); dim_imgsize[0], dim_x[0],
platform::errors::InvalidArgument(
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."));
}
PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2."); PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2.");
PADDLE_ENFORCE_GT(anchors.size(), 0, PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater than 0."); "Attr(anchors) length should be greater than 0.");
@ -110,6 +113,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Boxes with confidence scores under threshold should " "Boxes with confidence scores under threshold should "
"be ignored.") "be ignored.")
.SetDefault(0.01); .SetDefault(0.01);
AddAttr<bool>("clip_bbox",
"Whether to clip the output bounding box to the Input(ImgSize) "
"boundary. Default true.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network. This operator generates YOLO detection boxes from output of YOLOv3 network.

@ -26,7 +26,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh, T* scores, const float conf_thresh,
const int* anchors, const int n, const int h, const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num, const int w, const int an_num, const int class_num,
const int box_num, int input_size) { const int box_num, int input_size, bool clip_bbox) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
T box[4]; T box[4];
@ -53,7 +53,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx, GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width); grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4; box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width); CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
int label_idx = int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
@ -76,6 +76,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num"); int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh"); float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
const int n = input->dims()[0]; const int n = input->dims()[0];
const int h = input->dims()[2]; const int h = input->dims()[2];
@ -107,7 +108,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh, input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size); anchors_data, n, h, w, an_num, class_num, box_num, input_size,
clip_bbox);
} }
}; };

@ -47,21 +47,23 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
template <typename T> template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int img_height, const int img_height,
const int img_width) { const int img_width, bool clip_bbox) {
boxes[box_idx] = box[0] - box[2] / 2; boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2; boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2; boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2; boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0); if (clip_bbox) {
boxes[box_idx + 1] = boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0); boxes[box_idx + 1] =
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
? boxes[box_idx + 2] boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
: static_cast<T>(img_width - 1); ? boxes[box_idx + 2]
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 : static_cast<T>(img_width - 1);
? boxes[box_idx + 3] boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
: static_cast<T>(img_height - 1); ? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
} }
template <typename T> template <typename T>
@ -86,6 +88,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int class_num = ctx.Attr<int>("class_num"); int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh"); float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio"); int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
const int n = input->dims()[0]; const int n = input->dims()[0];
const int h = input->dims()[2]; const int h = input->dims()[2];
@ -130,8 +133,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size, GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width); box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4; box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
img_width); clip_bbox);
int label_idx = int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);

@ -1023,6 +1023,7 @@ def yolo_box(x,
class_num, class_num,
conf_thresh, conf_thresh,
downsample_ratio, downsample_ratio,
clip_bbox=True,
name=None): name=None):
""" """
${comment} ${comment}
@ -1034,6 +1035,7 @@ def yolo_box(x,
class_num (int): ${class_num_comment} class_num (int): ${class_num_comment}
conf_thresh (float): ${conf_thresh_comment} conf_thresh (float): ${conf_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment} downsample_ratio (int): ${downsample_ratio_comment}
clip_bbox (bool): ${clip_bbox_comment}
name (string): The default value is None. Normally there is no need name (string): The default value is None. Normally there is no need
for user to set this property. For more information, for user to set this property. For more information,
please refer to :ref:`api_guide_Name` please refer to :ref:`api_guide_Name`
@ -1081,6 +1083,7 @@ def yolo_box(x,
"class_num": class_num, "class_num": class_num,
"conf_thresh": conf_thresh, "conf_thresh": conf_thresh,
"downsample_ratio": downsample_ratio, "downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox,
} }
helper.append_op( helper.append_op(

@ -32,6 +32,7 @@ def YoloBox(x, img_size, attrs):
class_num = attrs['class_num'] class_num = attrs['class_num']
conf_thresh = attrs['conf_thresh'] conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample'] downsample = attrs['downsample']
clip_bbox = attrs['clip_bbox']
input_size = downsample * h input_size = downsample * h
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
@ -64,13 +65,14 @@ def YoloBox(x, img_size, attrs):
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis] pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis] pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
for i in range(len(pred_box)): if clip_bbox:
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf) for i in range(len(pred_box)):
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf) pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf, pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
img_size[i, 1] - 1) pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf, img_size[i, 1] - 1)
img_size[i, 0] - 1) pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
return pred_box, pred_score.reshape((n, -1, class_num)) return pred_box, pred_score.reshape((n, -1, class_num))
@ -87,6 +89,7 @@ class TestYoloBoxOp(OpTest):
"class_num": self.class_num, "class_num": self.class_num,
"conf_thresh": self.conf_thresh, "conf_thresh": self.conf_thresh,
"downsample": self.downsample, "downsample": self.downsample,
"clip_bbox": self.clip_bbox,
} }
self.inputs = { self.inputs = {
@ -109,6 +112,20 @@ class TestYoloBoxOp(OpTest):
self.class_num = 2 self.class_num = 2
self.conf_thresh = 0.5 self.conf_thresh = 0.5
self.downsample = 32 self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2) self.imgsize_shape = (self.batch_size, 2)

Loading…
Cancel
Save