|
|
|
@ -46,12 +46,17 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box,
|
|
|
|
|
const int box_idx) {
|
|
|
|
|
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
|
|
|
|
|
const int img_height, const int img_width) {
|
|
|
|
|
boxes[box_idx] = box[0] - box[2] / 2;
|
|
|
|
|
boxes[box_idx + 1] = box[1] - box[3] / 2;
|
|
|
|
|
boxes[box_idx + 2] = box[0] + box[2] / 2;
|
|
|
|
|
boxes[box_idx + 3] = box[1] + box[3] / 2;
|
|
|
|
|
|
|
|
|
|
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
|
|
|
|
|
boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
|
|
|
|
|
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast<T>(img_width - 1);
|
|
|
|
|
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast<T>(img_height - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -118,7 +123,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
|
|
|
|
|
GetYoloBox<T>(box, input_data, anchors_, l, k, j, h, input_size,
|
|
|
|
|
box_idx, stride, img_height, img_width);
|
|
|
|
|
box_idx = (i * box_num + j * stride + k * w + l) * 4;
|
|
|
|
|
CalcDetectionBox<T>(boxes_data, box, box_idx);
|
|
|
|
|
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width);
|
|
|
|
|
|
|
|
|
|
int label_idx =
|
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
|
|
|
|
|