|
|
|
@ -26,8 +26,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
|
T* scores, const float conf_thresh,
|
|
|
|
|
const int* anchors, const int n, const int h,
|
|
|
|
|
const int w, const int an_num, const int class_num,
|
|
|
|
|
const int box_num, int input_size, bool clip_bbox,
|
|
|
|
|
const float scale, const float bias) {
|
|
|
|
|
const int box_num, int input_size_h,
|
|
|
|
|
int input_size_w, bool clip_bbox, const float scale,
|
|
|
|
|
const float bias) {
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
T box[4];
|
|
|
|
@ -51,8 +52,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
|
|
|
|
|
|
int box_idx =
|
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
|
|
|
|
|
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
|
|
|
|
|
grid_num, img_height, img_width, scale, bias);
|
|
|
|
|
GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h,
|
|
|
|
|
input_size_w, box_idx, grid_num, img_height, img_width, scale,
|
|
|
|
|
bias);
|
|
|
|
|
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
|
|
|
|
|
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
|
|
|
|
|
|
|
|
|
@ -86,7 +88,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int w = input->dims()[3];
|
|
|
|
|
const int box_num = boxes->dims()[1];
|
|
|
|
|
const int an_num = anchors.size() / 2;
|
|
|
|
|
int input_size = downsample_ratio * h;
|
|
|
|
|
int input_size_h = downsample_ratio * h;
|
|
|
|
|
int input_size_w = downsample_ratio * w;
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.cuda_device_context();
|
|
|
|
|
int bytes = sizeof(int) * anchors.size();
|
|
|
|
@ -111,8 +114,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
|
|
|
|
|
anchors_data, n, h, w, an_num, class_num, box_num, input_size,
|
|
|
|
|
clip_bbox, scale, bias);
|
|
|
|
|
anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
|
|
|
|
|
input_size_w, clip_bbox, scale, bias);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|