reduce based nms final pass - speed improv

refactored faster nms

refactored faster nms + typo fix

added box flipping choice

set choice to true for testing - yz

switching back

new test file
pull/5458/head
danish 5 years ago
parent 0e65b3ba70
commit 7d7fa760a0

@ -36,12 +36,44 @@ __inline__ __device__ void Swap(T *lhs, T *rhs) {
rhs[0] = tmp;
}
// Initialize per row mask array to all true
__global__ void MaskInit(int numSq, bool *row_mask) {
for (int mat_pos = blockIdx.x * blockDim.x + threadIdx.x; mat_pos < numSq; mat_pos += blockDim.x * gridDim.x) {
row_mask[mat_pos] = true;
}
}
// copy data from input to output array sorted by indices returned from bitonic sort
// flips boxes if asked to, default - false -> if (x1/y1 > x2/y2)
template <typename T>
__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_) {
__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_,
bool flip_mode = false) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
int correct_index = index_buff[(num - 1) - box_num]; // flip the array around
for (int x = 0; x < 5; x++) {
data_out[(box_num * box_size_) + x] = data_in[(correct_index * box_size_) + x];
int correct_arr_start = correct_index * box_size_;
int current_arr_start = box_num * box_size_;
if (flip_mode) { // flip boxes
// check x
if (data_in[correct_arr_start + 0] > data_in[correct_arr_start + 2]) {
data_out[current_arr_start + 0] = data_in[correct_arr_start + 2];
data_out[current_arr_start + 2] = data_in[correct_arr_start + 0];
} else {
data_out[current_arr_start + 0] = data_in[correct_arr_start + 0];
data_out[current_arr_start + 2] = data_in[correct_arr_start + 2];
}
// check y
if (data_in[correct_arr_start + 1] > data_in[correct_arr_start + 3]) {
data_out[current_arr_start + 1] = data_in[correct_arr_start + 3];
data_out[current_arr_start + 3] = data_in[correct_arr_start + 1];
} else {
data_out[current_arr_start + 1] = data_in[correct_arr_start + 1];
data_out[current_arr_start + 3] = data_in[correct_arr_start + 3];
}
data_out[current_arr_start + 4] = data_in[correct_arr_start + 4];
} else { // default behaviour, don't flip
for (int x = 0; x < 5; x++) {
data_out[current_arr_start + x] = data_in[correct_arr_start + x];
}
}
}
}
@ -57,55 +89,55 @@ __inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, in
T height = max(y_2 - y_1, T(0));
T combined_area = area[box_A_ix] + area[box_B_ix];
// return decision to keep or remove box
return !(((width * height) / (combined_area - (width * height))) > IOU_value);
return !(((width * height) / (combined_area - (width * height))) >= IOU_value);
}
// calculate areas for boxes -> sorted by output boxes
// populated return mask (init to all true) and return index array
template <typename T>
__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
__global__ void Preprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *output, int box_size_) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
sel_idx[box_num] = box_num;
sel_boxes[box_num] = true;
area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
(output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
}
}
// Run parallel NMS pass
// Every box updates it's own mask in row_mask in sep threads
template <typename T>
__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
int box_size_) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
// represents highest score box in that GPU block
if (threadIdx.x == 0) {
sel_boxes[box_num] = true;
continue;
__global__ void NMSPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
bool *row_mask) {
int box_i_start_index, box_j_start_index; // actual input data indexing
int mask_offset = 0;
for (int box_i = blockIdx.x * blockDim.x + threadIdx.x; box_i < num - 1; box_i += blockDim.x * gridDim.x) {
mask_offset = box_i * num;
box_i_start_index = box_i * box_size_; // adjust starting index
for (int box_j = box_i + 1; box_j < num; box_j++) {
box_j_start_index = box_j * box_size_;
row_mask[mask_offset + box_j] =
IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
}
int box_start_index = box_num * box_size_; // start index adjustment
int block_max_box_num = ((blockIdx.x * blockDim.x) + 0);
int block_max_box_start_index = block_max_box_num * box_size_; // start index adjustment
sel_boxes[box_num] =
IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area,
IOU_value); // update mask
}
}
template <typename T>
__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
int box_i, box_j; // access all shared mem meta data with these
int box_i_start_index, box_j_start_index; // actual input data indexing
for (int i = 0; i < num - 1; i++) {
box_i = i;
box_i_start_index = box_i * box_size_; // adjust starting index
if (sel_boxes[box_i]) {
for (int j = i + 1; j < num; j++) {
box_j = j;
box_j_start_index = box_j * box_size_;
if (sel_boxes[box_j]) {
sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
}
}
// Reduce pass runs on 1 block to allow thread sync
__global__ void ReducePass(const int num, bool *sel_boxes, bool *row_mask) {
// loop over every box in order of high to low confidence score
for (int i = 0; i < num - 1; ++i) {
if (!sel_boxes[i]) {
continue;
}
// every thread handles a different set of boxes (per all boxes in order)
for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num; j += blockDim.x * gridDim.x) {
sel_boxes[j] = sel_boxes[j] && row_mask[i * num + j];
}
__syncthreads(); // sync all threads before moving all active threads to next iteration
}
}
// Sorting function based on BitonicSort from TopK kernel
template <typename T>
__global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
T *data_buff, int *index_buff, int box_size_) {
@ -139,41 +171,37 @@ __global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, con
}
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
cudaStream_t cuda_stream) {
PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_);
Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
void CalPreprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *input, T *output, int *index_buff,
int box_size_, bool *row_mask, cudaStream_t cuda_stream) {
int total_val = num * num;
MaskInit<<<GET_BLOCKS(total_val), GET_THREADS, 0, cuda_stream>>>(total_val, row_mask);
// default for flipping boxes -> false (provision available to flip if API updated)
PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_, false);
Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, sel_boxes, area, output, box_size_);
}
template <typename T>
void CalSortInit(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream) {
void CalSort(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream) {
int ceil_p_2 = NMSRoundUpPower2(num);
int thread = std::min(ceil_p_2, GET_THREADS);
NMS_BitonicSortByKeyKernel<<<1, thread, 0, stream>>>(1, num, ceil_p_2, data_in, data_buff, index_buff, box_size_);
}
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream) {
NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
box_size_);
void CalNMS(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, bool *row_mask,
cudaStream_t cuda_stream) {
NMSPass<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_,
row_mask);
ReducePass<<<1, GET_THREADS, 0, cuda_stream>>>(num, sel_boxes, row_mask);
}
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream) {
FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
}
template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *input, float *output,
int *index_buff, int box_size_, cudaStream_t cuda_stream);
template void CalSortInit<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
int box_size_, cudaStream_t stream);
template void CalSort<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
int box_size_, cudaStream_t stream);
template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
int box_size_, cudaStream_t cuda_stream);
template void CalPreprocess<float>(const int num, int *sel_idx, bool *sel_boxes, float *area, float *input,
float *output, int *index_buff, int box_size_, bool *row_mask,
cudaStream_t cuda_stream);
template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
int box_size_, cudaStream_t cuda_stream);
template void CalNMS<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
int box_size_, bool *row_mask, cudaStream_t cuda_stream);

@ -20,20 +20,16 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
cudaStream_t cuda_stream);
void CalSort(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream);
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream);
void CalPreprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *input, T *output, int *index_buff,
int box_size_, bool *row_mask, cudaStream_t cuda_stream);
template <typename T>
void CalSortInit(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream);
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream);
void CalNMS(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, bool *row_mask,
cudaStream_t cuda_stream);
int NMSRoundUpPower2(int v);

@ -41,21 +41,19 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *area = GetDeviceAddress<T>(workspace, 0); // store area values for all boxes
T *data_buff = GetDeviceAddress<T>(workspace, 1); // sort buffer
T *area = GetDeviceAddress<T>(workspace, 0);
T *data_buff = GetDeviceAddress<T>(workspace, 1);
int *index_buff = GetDeviceAddress<int>(workspace, 2);
bool *row_mask = GetDeviceAddress<bool>(workspace, 3);
T *output = GetDeviceAddress<T>(outputs, 0);
int *sel_idx = GetDeviceAddress<int>(outputs, 1);
bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);
CalSortInit(num_input_, input, output, index_buff, data_buff, box_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalPreprocess(num_input_, sel_idx, area, input, output, index_buff, box_size_,
CalSort(num_input_, input, output, index_buff, data_buff, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
CalPreprocess(num_input_, sel_idx, sel_boxes, area, input, output, index_buff, box_size_, row_mask,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNMS(num_input_, iou_value_, output, area, sel_boxes, box_size_, row_mask,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -87,8 +85,9 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
input_size_ = num_input_ * sizeof(T) * box_size_; // 5 values per bbox
output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
workspace_size_ = num_input_ * sizeof(int);
workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int));
workspace_size_ = num_input_ * sizeof(int); // storing areas
workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int)); // sorting buffers
workspace_size_ += (num_input_ * num_input_ * sizeof(bool)); // Row mask - NMS
InitSizeLists();
return true;
@ -103,9 +102,10 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
output_size_list_.push_back(num_input_ * sizeof(bool));
// N sized workspace arrs
workspace_size_list_.push_back(num_input_ * sizeof(T)); // area list
workspace_size_list_.push_back(ceil_power_2 * sizeof(T)); // data buff
workspace_size_list_.push_back(ceil_power_2 * sizeof(int)); // index buff
workspace_size_list_.push_back(num_input_ * sizeof(T)); // area list
workspace_size_list_.push_back(ceil_power_2 * sizeof(T)); // data buff
workspace_size_list_.push_back(ceil_power_2 * sizeof(int)); // index buff
workspace_size_list_.push_back(num_input_ * num_input_ * sizeof(bool)); // mask list
}
private:

@ -40,7 +40,7 @@ def test_nms_with_mask_check_order():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
nms_op = P.NMSWithMask(0.5)
for _ in range(10):
count = 8000
count = 4000
box = np.random.randint(1, 100, size=(count, 4))
box[:, 2] = box[:, 0] + box[:, 2]
box[:, 3] = box[:, 1] + box[:, 3]

Loading…
Cancel
Save