diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
index 36c1b2ee48..48afc57576 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
@@ -18,7 +18,7 @@
 #include <limits>
 #include <algorithm>
 
-int RoundUpPower2M(int v) {
+int NMSRoundUpPower2(int v) {
   v--;
   v |= v >> 1;
   v |= v >> 2;
@@ -30,12 +30,22 @@ int RoundUpPower2M(int v) {
 }
 
 template <typename T>
-__inline__ __device__ void SwapM(T *lhs, T *rhs) {
+__inline__ __device__ void Swap(T *lhs, T *rhs) {
   T tmp = lhs[0];
   lhs[0] = rhs[0];
   rhs[0] = tmp;
 }
 
+template <typename T>
+__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_) {
+  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
+    int correct_index = index_buff[(num - 1) - box_num];  // flip the array around
+    for (int x = 0; x < 5; x++) {
+      data_out[(box_num * box_size_) + x] = data_in[(correct_index * box_size_) + x];
+    }
+  }
+}
+
 template <typename T>
 __inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start, T *area,
                                        float IOU_value) {
@@ -96,38 +106,29 @@ __global__ void FinalPass(const int num, const float IOU_value, T *output, T *ar
   }
 }
 
-template <typename T, typename S>
-__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in,
-                                        S *data_out, T *index_buff, S *data_buff, int box_size_) {
-  // default: sort with share memory
-  extern __shared__ T share_mem_NMS[];
-  T *index_arr = share_mem_NMS;
-  S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2);
-  // sort with RAM
-  if (index_buff != nullptr && data_buff != nullptr) {
-    index_arr = index_buff + blockIdx.x * ceil_power2;
-    data_arr = data_buff + blockIdx.x * ceil_power2;
-  }
+template <typename T>
+__global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
+                                           T *data_buff, int *index_buff, int box_size_) {
   for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
-    index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max();
-    // populated directly from input data
-    data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max();
+    data_buff[i] = (i < inner) ? input[(i * box_size_) + 4] : std::numeric_limits<T>::max();
+    index_buff[i] = i;
   }
   __syncthreads();
+
   for (size_t i = 2; i <= ceil_power2; i <<= 1) {
     for (size_t j = (i >> 1); j > 0; j >>= 1) {
       for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
         size_t tid_comp = tid ^ j;
         if (tid_comp > tid) {
           if ((tid & i) == 0) {
-            if (data_arr[tid] > data_arr[tid_comp]) {
-              SwapM(&index_arr[tid], &index_arr[tid_comp]);
-              SwapM(&data_arr[tid], &data_arr[tid_comp]);
+            if (data_buff[tid] > data_buff[tid_comp]) {
+              Swap(&data_buff[tid], &data_buff[tid_comp]);
+              Swap(&index_buff[tid], &index_buff[tid_comp]);
             }
           } else {
-            if (data_arr[tid] < data_arr[tid_comp]) {
-              SwapM(&index_arr[tid], &index_arr[tid_comp]);
-              SwapM(&data_arr[tid], &data_arr[tid_comp]);
+            if (data_buff[tid] < data_buff[tid_comp]) {
+              Swap(&data_buff[tid], &data_buff[tid_comp]);
+              Swap(&index_buff[tid], &index_buff[tid_comp]);
             }
           }
         }
@@ -135,36 +136,21 @@ __global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const
       __syncthreads();
     }
   }
-  T correct_index;
-  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
-    correct_index = index_arr[(inner - 1) - tid];
-    // moved data from input to output, correct ordering using sorted index array
-    for (auto i : {0, 1, 2, 3, 4}) {
-      data_out[(blockIdx.x * inner + tid) * box_size_ + i] =
-        data_in[(blockIdx.x * inner + correct_index) * box_size_ + i];
-    }
-  }
 }
 
 template <typename T>
-void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) {
+void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
+                   cudaStream_t cuda_stream) {
+  PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_);
   Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
 }
 
-template <typename T, typename S>
-void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
-                       int box_size_, cudaStream_t stream) {
-  int ceil_power2 = RoundUpPower2M(inner);
-  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
-  if (share_mem > SHARED_MEM_PER_BLOCK) {
-    share_mem = 0;
-  } else {
-    data_buff = nullptr;
-    index_buff = nullptr;
-  }
-  int thread = std::min(ceil_power2, GET_THREADS);
-  BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out,
-                                                                index_buff, data_buff, box_size_);
+template <typename T>
+void CalSortInit(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
+                 cudaStream_t stream) {
+  int ceil_p_2 = NMSRoundUpPower2(num);
+  int thread = std::min(ceil_p_2, GET_THREADS);
+  NMS_BitonicSortByKeyKernel<<<1, thread, 0, stream>>>(1, num, ceil_p_2, data_in, data_buff, index_buff, box_size_);
 }
 
 template <typename T>
@@ -180,11 +166,11 @@ void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool
   FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
 }
 
-template void CalPreprocess(const int num, int *sel_idx, float *area, float *output, int box_size_,
-                            cudaStream_t cuda_stream);
+template void CalPreprocess(const int num, int *sel_idx, float *area, float *input, float *output,
+                            int *index_buff, int box_size_, cudaStream_t cuda_stream);
 
-template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff,
-                                float *data_buff, int box_size_, cudaStream_t stream);
+template void CalSortInit(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
+                          int box_size_, cudaStream_t stream);
 
 template void CalNMSWithMask(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                              int box_size_, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
index 0eafd51389..b20c6704ed 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
@@ -20,18 +20,21 @@
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
-void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream);
+void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
+                   cudaStream_t cuda_stream);
 
 template <typename T>
 void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                     cudaStream_t cuda_stream);
 
-template <typename T, typename S>
-void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
-                       int box_size_, cudaStream_t stream);
+template <typename T>
+void CalSortInit(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
+                 cudaStream_t stream);
 
 template <typename T>
 void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                   cudaStream_t cuda_stream);
 
+int NMSRoundUpPower2(int v);
+
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
index a5e0464cb9..7c22d80ba7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
@@ -30,7 +30,8 @@ namespace kernel {
 template <typename T>
 class NMSWithMaskGpuFwdKernel : public GpuKernel {
  public:
-  NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {}
+  NMSWithMaskGpuFwdKernel()
+      : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0), ceil_power_2(0) {}
   ~NMSWithMaskGpuFwdKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,22 +41,24 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input = GetDeviceAddress<T>(inputs, 0);
-    T *data_buff = GetDeviceAddress<T>(workspace, 0);  // sort buffer
-    int *index_buff = GetDeviceAddress<int>(workspace, 1);
-    T *area = GetDeviceAddress<T>(workspace, 2);  // store area values for all boxes
+    T *area = GetDeviceAddress<T>(workspace, 0);  // store area values for all boxes
+    T *data_buff = GetDeviceAddress<T>(workspace, 1);  // sort buffer
+    int *index_buff = GetDeviceAddress<int>(workspace, 2);
 
     T *output = GetDeviceAddress<T>(outputs, 0);
     int *sel_idx = GetDeviceAddress<int>(outputs, 1);
     bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);
 
-    BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalSortInit(num_input_, input, output, index_buff, data_buff, box_size_,
+                reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalPreprocess(num_input_, sel_idx, area, input, output, index_buff, box_size_,
+                  reinterpret_cast<cudaStream_t>(stream_ptr));
     CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                    reinterpret_cast<cudaStream_t>(stream_ptr));
     CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
+
   bool Init(const CNodePtr &kernel_node) override {
     iou_value_ = GetAttr<float>(kernel_node, "iou_threshold");
@@ -79,10 +82,13 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
     }
 
     num_input_ = input_shape[0];  // Get N value in [N,5] data
+    ceil_power_2 = NMSRoundUpPower2(num_input_);
 
     input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
     output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
-    workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int));
+
+    workspace_size_ = num_input_ * sizeof(int);
+    workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int));
 
     InitSizeLists();
     return true;
@@ -97,20 +103,20 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
     output_size_list_.push_back(num_input_ * sizeof(bool));
 
     // N sized workspace arrs
-    workspace_size_list_.push_back(num_input_ * sizeof(T));
-    workspace_size_list_.push_back(num_input_ * sizeof(int));
-    workspace_size_list_.push_back(num_input_ * sizeof(T));
+    workspace_size_list_.push_back(num_input_ * sizeof(T));      // area list
+    workspace_size_list_.push_back(ceil_power_2 * sizeof(T));    // data buff
+    workspace_size_list_.push_back(ceil_power_2 * sizeof(int));  // index buff
   }
 
  private:
   int num_input_;
   float iou_value_;
   static const int box_size_ = 5;  // pre_defined box width
-  // int box_size__ = 5; // current size of bboxes
   // default values
   size_t input_size_;
   size_t output_size_;
   size_t workspace_size_;
+  size_t ceil_power_2;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
diff --git a/tests/st/ops/gpu/test_nms_with_mask_op.py b/tests/st/ops/gpu/test_nms_with_mask_op.py
index 210a14b3ff..ed0be91a84 100644
--- a/tests/st/ops/gpu/test_nms_with_mask_op.py
+++ b/tests/st/ops/gpu/test_nms_with_mask_op.py
@@ -21,29 +21,6 @@ import mindspore
 from mindspore import Tensor
 from mindspore.ops import operations as P
 
-def manualNMS(bbox, overlap_val_iou):
-    mask = [True] * len(bbox)
-    for box_a_index, _ in enumerate(bbox):
-        if not mask[box_a_index]:
-            continue  # ignore if not in list
-        box_a = bbox[box_a_index]  # select box for value extraction
-        for box_b_index in range(box_a_index + 1, len(bbox)):
-            if not mask[box_b_index]:
-                continue  # ignore if not in list
-            box_b = bbox[box_b_index]
-            areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
-            areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
-            overlap_x1 = max(box_a[0], box_b[0])
-            overlap_y1 = max(box_a[1], box_b[1])
-            overlap_x2 = min(box_a[2], box_b[2])
-            overlap_y2 = min(box_a[3], box_b[3])
-            width = max((overlap_x2 - overlap_x1), 0)
-            height = max((overlap_y2 - overlap_y1), 0)
-            # generate IOU decision
-            mask[box_b_index] = not (
-                (width * height)/(areaA + areaB - (width * height))) > overlap_val_iou
-    return mask
-
 
 def runMSRun(op, bbox):
     inputs = Tensor(bbox, mindspore.float32)
@@ -60,10 +37,10 @@ def runMSRun(op, bbox):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_nms_with_mask_check_order():
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     nms_op = P.NMSWithMask(0.5)
-    for _ in range(500):
-        count = 20
+    for _ in range(10):
+        count = 8000
         box = np.random.randint(1, 100, size=(count, 4))
         box[:, 2] = box[:, 0] + box[:, 2]
         box[:, 3] = box[:, 1] + box[:, 3]
@@ -77,28 +54,6 @@ def test_nms_with_mask_check_order():
                                           ms_sorted_scores, np_sorted_scores)
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_nms_with_masl_check_result():
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    test_count = 500
-    for x in range(1, test_count+1):
-        count = 20  # size of bbox lists
-        nms_op = P.NMSWithMask(x * 0.002)  # will test full range b/w 0 and 1
-        box = np.random.randint(1, 100, size=(count, 4))
-        box[:, 2] = box[:, 0] + box[:, 2]
-        box[:, 3] = box[:, 1] + box[:, 3]
-        unsorted_scores = np.random.rand(count, 1)
-        sorted_scores = np.sort(unsorted_scores, axis=0)[::-1]
-        bbox = np.hstack((box, sorted_scores))
-        bbox = Tensor(bbox, dtype=mindspore.float32)
-        _, _, mask = nms_op(bbox)
-        mask = mask.asnumpy()
-        manual_mask = manualNMS(box, x * 0.002)
-        np.testing.assert_array_equal(mask, np.array(manual_mask))
-
-
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard