|
|
|
@ -20,7 +20,6 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
HOSTDEVICE inline T sigmoid(T x) {
|
|
|
|
|
return 1.0 / (1.0 + std::exp(-x));
|
|
|
|
@ -28,15 +27,15 @@ HOSTDEVICE inline T sigmoid(T x) {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
|
|
|
|
|
int j, int an_idx, int grid_size,
|
|
|
|
|
int input_size, int index, int stride,
|
|
|
|
|
int img_height, int img_width) {
|
|
|
|
|
int j, int an_idx, int grid_size,
|
|
|
|
|
int input_size, int index, int stride,
|
|
|
|
|
int img_height, int img_width) {
|
|
|
|
|
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
|
|
|
|
|
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
|
|
|
|
|
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
|
|
|
|
|
input_size;
|
|
|
|
|
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height /
|
|
|
|
|
input_size;
|
|
|
|
|
input_size;
|
|
|
|
|
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
|
|
|
|
|
img_height / input_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
|
|
|
|
@ -47,16 +46,22 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
|
|
|
|
|
|
|
|
|
|
// Converts a (center, size) box into clamped corner coordinates and writes
// them to four consecutive slots of `boxes` starting at `box_idx`.
//
// The corners are computed as
//   boxes[box_idx]     = box[0] - box[2] / 2
//   boxes[box_idx + 1] = box[1] - box[3] / 2
//   boxes[box_idx + 2] = box[0] + box[2] / 2
//   boxes[box_idx + 3] = box[1] + box[3] / 2
// and then clamped: the first two are raised to 0 when not positive, the
// last two are lowered to `img_width - 1` / `img_height - 1` when they are
// not strictly below those limits.
//
// Parameters:
//   boxes      - output buffer; elements box_idx .. box_idx + 3 are written.
//   box        - input box; elements 0..3 are read and not modified here.
//   box_idx    - index of the first output element in `boxes`.
//   img_height - upper clamp for boxes[box_idx + 3] is img_height - 1.
//   img_width  - upper clamp for boxes[box_idx + 2] is img_width - 1.
//
// Only one bound is applied per coordinate: the first two are not capped
// from above and the last two are not raised to 0.
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
                                        const int img_height,
                                        const int img_width) {
  boxes[box_idx] = box[0] - box[2] / 2;
  boxes[box_idx + 1] = box[1] - box[3] / 2;
  boxes[box_idx + 2] = box[0] + box[2] / 2;
  boxes[box_idx + 3] = box[1] + box[3] / 2;

  // Lower-bound the first two coordinates at 0.
  boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
  boxes[box_idx + 1] =
      boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);

  // Upper-bound the last two coordinates at (img_width - 1, img_height - 1).
  boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
                           ? boxes[box_idx + 2]
                           : static_cast<T>(img_width - 1);
  boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
                           ? boxes[box_idx + 3]
                           : static_cast<T>(img_height - 1);
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -92,8 +97,10 @@ class YoloBoxKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int stride = h * w;
|
|
|
|
|
const int an_stride = (class_num + 5) * stride;
|
|
|
|
|
|
|
|
|
|
int anchors_[anchors.size()];
|
|
|
|
|
std::copy(anchors.begin(), anchors.end(), anchors_);
|
|
|
|
|
Tensor anchors_;
|
|
|
|
|
auto anchors_data =
|
|
|
|
|
anchors_.mutable_data<int>({an_num * 2}, ctx.GetPlace());
|
|
|
|
|
std::copy(anchors.begin(), anchors.end(), anchors_data);
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
const int* imgsize_data = imgsize->data<int>();
|
|
|
|
@ -120,10 +127,11 @@ class YoloBoxKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
int box_idx =
|
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
|
|
|
|
|
GetYoloBox<T>(box, input_data, anchors_, l, k, j, h, input_size,
|
|
|
|
|
box_idx, stride, img_height, img_width);
|
|
|
|
|
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
|
|
|
|
|
box_idx, stride, img_height, img_width);
|
|
|
|
|
box_idx = (i * box_num + j * stride + k * w + l) * 4;
|
|
|
|
|
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width);
|
|
|
|
|
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height,
|
|
|
|
|
img_width);
|
|
|
|
|
|
|
|
|
|
int label_idx =
|
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
|
|
|
|
|