Merge pull request #15919 from heavengate/yolo_box
add yolo_box for detection box calc in YOLOv3revert-16190-refine_parallel_executor
commit
b77ebb2af2
@ -0,0 +1,167 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/yolo_box_op.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
|
||||
class YoloBoxOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of YoloBoxOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("ImgSize"),
|
||||
"Input(ImgSize) of YoloBoxOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Boxes"),
|
||||
"Output(Boxes) of YoloBoxOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Scores"),
|
||||
"Output(Scores) of YoloBoxOp should not be null.");
|
||||
|
||||
auto dim_x = ctx->GetInputDim("X");
|
||||
auto dim_imgsize = ctx->GetInputDim("ImgSize");
|
||||
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
|
||||
int anchor_num = anchors.size() / 2;
|
||||
auto class_num = ctx->Attrs().Get<int>("class_num");
|
||||
|
||||
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
dim_x[1], anchor_num * (5 + class_num),
|
||||
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
|
||||
"+ class_num)).");
|
||||
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
|
||||
"Input(ImgSize) should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
dim_imgsize[0], dim_x[0],
|
||||
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same.");
|
||||
PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2.");
|
||||
PADDLE_ENFORCE_GT(anchors.size(), 0,
|
||||
"Attr(anchors) length should be greater than 0.");
|
||||
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
|
||||
"Attr(anchors) length should be even integer.");
|
||||
PADDLE_ENFORCE_GT(class_num, 0,
|
||||
"Attr(class_num) should be an integer greater than 0.");
|
||||
|
||||
int box_num = dim_x[2] * dim_x[3] * anchor_num;
|
||||
std::vector<int64_t> dim_boxes({dim_x[0], box_num, 4});
|
||||
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_boxes));
|
||||
|
||||
std::vector<int64_t> dim_scores({dim_x[0], box_num, class_num});
|
||||
ctx->SetOutputDim("Scores", framework::make_ddim(dim_scores));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X",
|
||||
"The input tensor of YoloBox operator is a 4-D tensor with "
|
||||
"shape of [N, C, H, W]. The second dimension(C) stores "
|
||||
"box locations, confidence score and classification one-hot "
|
||||
"keys of each anchor box. Generally, X should be the output "
|
||||
"of YOLOv3 network.");
|
||||
AddInput("ImgSize",
|
||||
"The image size tensor of YoloBox operator, "
|
||||
"This is a 2-D tensor with shape of [N, 2]. This tensor holds "
|
||||
"height and width of each input image used for resizing output "
|
||||
"box in input image scale.");
|
||||
AddOutput("Boxes",
|
||||
"The output tensor of detection boxes of YoloBox operator, "
|
||||
"This is a 3-D tensor with shape of [N, M, 4], N is the "
|
||||
"batch num, M is output box number, and the 3rd dimension "
|
||||
"stores [xmin, ymin, xmax, ymax] coordinates of boxes.");
|
||||
AddOutput("Scores",
|
||||
"The output tensor of detection boxes scores of YoloBox "
|
||||
"operator, This is a 3-D tensor with shape of "
|
||||
"[N, M, :attr:`class_num`], N is the batch num, M is "
|
||||
"output box number.");
|
||||
|
||||
AddAttr<int>("class_num", "The number of classes to predict.");
|
||||
AddAttr<std::vector<int>>("anchors",
|
||||
"The anchor width and height, "
|
||||
"it will be parsed pair by pair.")
|
||||
.SetDefault(std::vector<int>{});
|
||||
AddAttr<int>("downsample_ratio",
|
||||
"The downsample ratio from network input to YoloBox operator "
|
||||
"input, so 32, 16, 8 should be set for the first, second, "
|
||||
"and thrid YoloBox operators.")
|
||||
.SetDefault(32);
|
||||
AddAttr<float>("conf_thresh",
|
||||
"The confidence scores threshold of detection boxes. "
|
||||
"Boxes with confidence scores under threshold should "
|
||||
"be ignored.")
|
||||
.SetDefault(0.01);
|
||||
AddComment(R"DOC(
|
||||
This operator generates YOLO detection boxes from output of YOLOv3 network.
|
||||
|
||||
The output of previous network is in shape [N, C, H, W], while H and W
|
||||
should be the same, H and W specify the grid size, each grid point predict
|
||||
given number boxes, this given number, which following will be represented as S,
|
||||
is specified by the number of anchors. In the second dimension(the channel
|
||||
dimension), C should be equal to S * (5 + class_num), class_num is the object
|
||||
category number of source dataset(such as 80 in coco dataset), so the
|
||||
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
|
||||
also includes confidence score of the box and class one-hot key of each anchor
|
||||
box.
|
||||
|
||||
Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box
|
||||
predictions should be as follows:
|
||||
|
||||
$$
|
||||
b_x = \\sigma(t_x) + c_x
|
||||
$$
|
||||
$$
|
||||
b_y = \\sigma(t_y) + c_y
|
||||
$$
|
||||
$$
|
||||
b_w = p_w e^{t_w}
|
||||
$$
|
||||
$$
|
||||
b_h = p_h e^{t_h}
|
||||
$$
|
||||
|
||||
in the equation above, :math:`c_x, c_y` is the left top corner of current grid
|
||||
and :math:`p_w, p_h` is specified by anchors.
|
||||
|
||||
The logistic regression value of the 5th channel of each anchor prediction boxes
|
||||
represents the confidence score of each prediction box, and the logistic
|
||||
regression value of the last :attr:`class_num` channels of each anchor prediction
|
||||
boxes represents the classifcation scores. Boxes with confidence scores less than
|
||||
:attr:`conf_thresh` should be ignored, and box final scores is the product of
|
||||
confidence scores and classification scores.
|
||||
|
||||
$$
|
||||
score_{pred} = score_{conf} * score_{class}
|
||||
$$
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
|
||||
ops::YoloBoxKernel<double>);
|
@ -0,0 +1,120 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/yolo_box_op.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
||||
T* scores, const float conf_thresh,
|
||||
const int* anchors, const int n, const int h,
|
||||
const int w, const int an_num, const int class_num,
|
||||
const int box_num, int input_size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
T box[4];
|
||||
for (; tid < n * box_num; tid += stride) {
|
||||
int grid_num = h * w;
|
||||
int i = tid / box_num;
|
||||
int j = (tid % box_num) / grid_num;
|
||||
int k = (tid % grid_num) / w;
|
||||
int l = tid % w;
|
||||
|
||||
int an_stride = (5 + class_num) * grid_num;
|
||||
int img_height = imgsize[2 * i];
|
||||
int img_width = imgsize[2 * i + 1];
|
||||
|
||||
int obj_idx =
|
||||
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
|
||||
T conf = sigmoid<T>(input[obj_idx]);
|
||||
if (conf < conf_thresh) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int box_idx =
|
||||
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
|
||||
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
|
||||
grid_num, img_height, img_width);
|
||||
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
|
||||
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
|
||||
|
||||
int label_idx =
|
||||
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
|
||||
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
|
||||
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
|
||||
grid_num);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<Tensor>("X");
|
||||
auto* img_size = ctx.Input<Tensor>("ImgSize");
|
||||
auto* boxes = ctx.Output<Tensor>("Boxes");
|
||||
auto* scores = ctx.Output<Tensor>("Scores");
|
||||
|
||||
auto anchors = ctx.Attr<std::vector<int>>("anchors");
|
||||
int class_num = ctx.Attr<int>("class_num");
|
||||
float conf_thresh = ctx.Attr<float>("conf_thresh");
|
||||
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
|
||||
|
||||
const int n = input->dims()[0];
|
||||
const int h = input->dims()[2];
|
||||
const int w = input->dims()[3];
|
||||
const int box_num = boxes->dims()[1];
|
||||
const int an_num = anchors.size() / 2;
|
||||
int input_size = downsample_ratio * h;
|
||||
|
||||
auto& dev_ctx = ctx.cuda_device_context();
|
||||
auto& allocator =
|
||||
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
|
||||
int bytes = sizeof(int) * anchors.size();
|
||||
auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
|
||||
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
|
||||
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
|
||||
const auto cplace = platform::CPUPlace();
|
||||
memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes,
|
||||
dev_ctx.stream());
|
||||
|
||||
const T* input_data = input->data<T>();
|
||||
const int* imgsize_data = img_size->data<int>();
|
||||
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
|
||||
T* scores_data =
|
||||
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
|
||||
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
|
||||
set_zero(dev_ctx, boxes, static_cast<T>(0));
|
||||
set_zero(dev_ctx, scores, static_cast<T>(0));
|
||||
|
||||
int grid_dim = (n * box_num + 512 - 1) / 512;
|
||||
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
||||
|
||||
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
||||
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
|
||||
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
|
||||
ops::YoloBoxOpCUDAKernel<double>);
|
@ -0,0 +1,149 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/hostdevice.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
HOSTDEVICE inline T sigmoid(T x) {
|
||||
return 1.0 / (1.0 + std::exp(-x));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
|
||||
int j, int an_idx, int grid_size,
|
||||
int input_size, int index, int stride,
|
||||
int img_height, int img_width) {
|
||||
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
|
||||
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
|
||||
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
|
||||
input_size;
|
||||
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
|
||||
img_height / input_size;
|
||||
}
|
||||
|
||||
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
|
||||
int an_num, int an_stride, int stride,
|
||||
int entry) {
|
||||
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
|
||||
const int img_height,
|
||||
const int img_width) {
|
||||
boxes[box_idx] = box[0] - box[2] / 2;
|
||||
boxes[box_idx + 1] = box[1] - box[3] / 2;
|
||||
boxes[box_idx + 2] = box[0] + box[2] / 2;
|
||||
boxes[box_idx + 3] = box[1] + box[3] / 2;
|
||||
|
||||
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
|
||||
boxes[box_idx + 1] =
|
||||
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
|
||||
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
|
||||
? boxes[box_idx + 2]
|
||||
: static_cast<T>(img_width - 1);
|
||||
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
|
||||
? boxes[box_idx + 3]
|
||||
: static_cast<T>(img_height - 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input,
|
||||
const int label_idx, const int score_idx,
|
||||
const int class_num, const T conf,
|
||||
const int stride) {
|
||||
for (int i = 0; i < class_num; i++) {
|
||||
scores[score_idx + i] = conf * sigmoid<T>(input[label_idx + i * stride]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class YoloBoxKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<Tensor>("X");
|
||||
auto* imgsize = ctx.Input<Tensor>("ImgSize");
|
||||
auto* boxes = ctx.Output<Tensor>("Boxes");
|
||||
auto* scores = ctx.Output<Tensor>("Scores");
|
||||
auto anchors = ctx.Attr<std::vector<int>>("anchors");
|
||||
int class_num = ctx.Attr<int>("class_num");
|
||||
float conf_thresh = ctx.Attr<float>("conf_thresh");
|
||||
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
|
||||
|
||||
const int n = input->dims()[0];
|
||||
const int h = input->dims()[2];
|
||||
const int w = input->dims()[3];
|
||||
const int box_num = boxes->dims()[1];
|
||||
const int an_num = anchors.size() / 2;
|
||||
int input_size = downsample_ratio * h;
|
||||
|
||||
const int stride = h * w;
|
||||
const int an_stride = (class_num + 5) * stride;
|
||||
|
||||
Tensor anchors_;
|
||||
auto anchors_data =
|
||||
anchors_.mutable_data<int>({an_num * 2}, ctx.GetPlace());
|
||||
std::copy(anchors.begin(), anchors.end(), anchors_data);
|
||||
|
||||
const T* input_data = input->data<T>();
|
||||
const int* imgsize_data = imgsize->data<int>();
|
||||
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
|
||||
memset(boxes_data, 0, boxes->numel() * sizeof(T));
|
||||
T* scores_data =
|
||||
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
|
||||
memset(scores_data, 0, scores->numel() * sizeof(T));
|
||||
|
||||
T box[4];
|
||||
for (int i = 0; i < n; i++) {
|
||||
int img_height = imgsize_data[2 * i];
|
||||
int img_width = imgsize_data[2 * i + 1];
|
||||
|
||||
for (int j = 0; j < an_num; j++) {
|
||||
for (int k = 0; k < h; k++) {
|
||||
for (int l = 0; l < w; l++) {
|
||||
int obj_idx =
|
||||
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4);
|
||||
T conf = sigmoid<T>(input_data[obj_idx]);
|
||||
if (conf < conf_thresh) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int box_idx =
|
||||
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
|
||||
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
|
||||
box_idx, stride, img_height, img_width);
|
||||
box_idx = (i * box_num + j * stride + k * w + l) * 4;
|
||||
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height,
|
||||
img_width);
|
||||
|
||||
int label_idx =
|
||||
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
|
||||
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
|
||||
CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx,
|
||||
class_num, conf, stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,117 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
from paddle.fluid import core
|
||||
|
||||
|
||||
def sigmoid(x):
|
||||
return 1.0 / (1.0 + np.exp(-1.0 * x))
|
||||
|
||||
|
||||
def YoloBox(x, img_size, attrs):
|
||||
n, c, h, w = x.shape
|
||||
anchors = attrs['anchors']
|
||||
an_num = int(len(anchors) // 2)
|
||||
class_num = attrs['class_num']
|
||||
conf_thresh = attrs['conf_thresh']
|
||||
downsample = attrs['downsample']
|
||||
input_size = downsample * h
|
||||
|
||||
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
|
||||
|
||||
pred_box = x[:, :, :, :, :4].copy()
|
||||
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
|
||||
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
|
||||
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
|
||||
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
|
||||
|
||||
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
|
||||
anchors_s = np.array(
|
||||
[(an_w / input_size, an_h / input_size) for an_w, an_h in anchors])
|
||||
anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
|
||||
anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
|
||||
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
|
||||
pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h
|
||||
|
||||
pred_conf = sigmoid(x[:, :, :, :, 4:5])
|
||||
pred_conf[pred_conf < conf_thresh] = 0.
|
||||
pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
|
||||
pred_box = pred_box * (pred_conf > 0.).astype('float32')
|
||||
|
||||
pred_box = pred_box.reshape((n, -1, 4))
|
||||
pred_box[:, :, :2], pred_box[:, :, 2:4] = \
|
||||
pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \
|
||||
pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
|
||||
pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]
|
||||
pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]
|
||||
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
|
||||
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
|
||||
|
||||
for i in range(len(pred_box)):
|
||||
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
|
||||
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
|
||||
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
|
||||
img_size[i, 1] - 1)
|
||||
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
|
||||
img_size[i, 0] - 1)
|
||||
|
||||
return pred_box, pred_score.reshape((n, -1, class_num))
|
||||
|
||||
|
||||
class TestYoloBoxOp(OpTest):
|
||||
def setUp(self):
|
||||
self.initTestCase()
|
||||
self.op_type = 'yolo_box'
|
||||
x = np.random.random(self.x_shape).astype('float32')
|
||||
img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
|
||||
|
||||
self.attrs = {
|
||||
"anchors": self.anchors,
|
||||
"class_num": self.class_num,
|
||||
"conf_thresh": self.conf_thresh,
|
||||
"downsample": self.downsample,
|
||||
}
|
||||
|
||||
self.inputs = {
|
||||
'X': x,
|
||||
'ImgSize': img_size,
|
||||
}
|
||||
boxes, scores = YoloBox(x, img_size, self.attrs)
|
||||
self.outputs = {
|
||||
"Boxes": boxes,
|
||||
"Scores": scores,
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def initTestCase(self):
|
||||
self.anchors = [10, 13, 16, 30, 33, 23]
|
||||
an_num = int(len(self.anchors) // 2)
|
||||
self.batch_size = 32
|
||||
self.class_num = 2
|
||||
self.conf_thresh = 0.5
|
||||
self.downsample = 32
|
||||
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
|
||||
self.imgsize_shape = (self.batch_size, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue