|
|
|
@ -26,7 +26,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
|
T* scores, const float conf_thresh,
|
|
|
|
|
const int* anchors, const int n, const int h,
|
|
|
|
|
const int w, const int an_num, const int class_num,
|
|
|
|
|
const int box_num, int input_size) {
|
|
|
|
|
const int box_num, int input_size, bool clip_bbox) {
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
T box[4];
|
|
|
|
@ -53,7 +53,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
|
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
|
|
|
|
|
grid_num, img_height, img_width);
|
|
|
|
|
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
|
|
|
|
|
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
|
|
|
|
|
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
|
|
|
|
|
|
|
|
|
|
int label_idx =
|
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
|
|
|
|
@ -76,6 +76,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int class_num = ctx.Attr<int>("class_num");
|
|
|
|
|
float conf_thresh = ctx.Attr<float>("conf_thresh");
|
|
|
|
|
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
|
|
|
|
|
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
|
|
|
|
|
|
|
|
|
|
const int n = input->dims()[0];
|
|
|
|
|
const int h = input->dims()[2];
|
|
|
|
@ -107,7 +108,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
|
|
|
|
|
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
|
|
|
|
|
anchors_data, n, h, w, an_num, class_num, box_num, input_size,
|
|
|
|
|
clip_bbox);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|