diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
new file mode 100644
index 0000000000..36c1b2ee48
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
@@ -0,0 +1,193 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nms_with_mask_impl.cuh"
+#include <limits>
+#include <algorithm>
+
+int RoundUpPower2M(int v) {
+  v--;
+  v |= v >> 1;
+  v |= v >> 2;
+  v |= v >> 4;
+  v |= v >> 8;
+  v |= v >> 16;
+  v++;
+  return v;
+}
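+
+// e.g. RoundUpPower2M(20): v-- gives 19 (0b10011); the or-shift cascade smears the
+// top set bit into every lower bit, giving 31 (0b11111); v++ then yields 32. The
+// initial decrement keeps exact powers of two unchanged (16 stays 16 rather than
+// rounding up to 32).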
+
+template <typename T>
+__inline__ __device__ void SwapM(T *lhs, T *rhs) {
+  T tmp = lhs[0];
+  lhs[0] = rhs[0];
+  rhs[0] = tmp;
+}
+
+template <typename T>
+__inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start,
+                                       T *area, float IOU_value) {
+  T x_1 = max(output[box_A_start + 0], output[box_B_start + 0]);
+  T y_1 = max(output[box_A_start + 1], output[box_B_start + 1]);
+  T x_2 = min(output[box_A_start + 2], output[box_B_start + 2]);
+  T y_2 = min(output[box_A_start + 3], output[box_B_start + 3]);
+  T width = max(x_2 - x_1, T(0));  // in case of no overlap
+  T height = max(y_2 - y_1, T(0));
+  T combined_area = area[box_A_ix] + area[box_B_ix];
+  // return decision to keep (true) or remove (false) the box
+  return !(((width * height) / (combined_area - (width * height))) > IOU_value);
+}
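+
+// e.g. with boxes A = (0, 0, 10, 10) and B = (5, 0, 15, 10): the overlap is
+// 5 * 10 = 50 and the union is 100 + 100 - 50 = 150, so IOU = 50 / 150 = 0.33.
+// With IOU_value = 0.5 the test 0.33 > 0.5 fails, so the box is kept (true).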
+
+template <typename T>
+__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
+  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
+    sel_idx[box_num] = box_num;
+    area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
+                    (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
+  }
+}
+
+template <typename T>
+__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
+                                  int box_size_) {
+  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
+    // thread 0 owns the block's highest-scoring box (input is sorted by descending score)
+    if (threadIdx.x == 0) {
+      sel_boxes[box_num] = true;
+      continue;
+    }
+    int box_start_index = box_num * box_size_;  // start index adjustment
+    int block_max_box_num = blockIdx.x * blockDim.x;
+    int block_max_box_start_index = block_max_box_num * box_size_;  // start index adjustment
+    sel_boxes[box_num] =
+      IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area,
+                  IOU_value);  // update mask
+  }
+}
+
+template <typename T>
+__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
+  int box_i, box_j;                          // box indices
+  int box_i_start_index, box_j_start_index;  // start offsets into the flattened box data
+  for (int i = 0; i < num - 1; i++) {
+    box_i = i;
+    box_i_start_index = box_i * box_size_;  // adjust starting index
+    if (sel_boxes[box_i]) {
+      for (int j = i + 1; j < num; j++) {
+        box_j = j;
+        box_j_start_index = box_j * box_size_;
+        if (sel_boxes[box_j]) {
+          sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
+        }
+      }
+    }
+  }
+}
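+
+// The kernel below gives each block one [inner, box_size_] slice to sort by score
+// (column 4) with a bitonic network: substage j pairs element tid with tid ^ j and
+// orders the pair ascending or descending according to (tid & i), completing in
+// O(log^2 ceil_power2) substages. Padding slots beyond `inner` hold
+// numeric_limits::max() so they sink to the tail, and the final copy walks
+// index_arr in reverse to emit boxes in descending score order.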
+
+template <typename T, typename S>
+__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in,
+                                        S *data_out, T *index_buff, S *data_buff, int box_size_) {
+  // default: sort within shared memory
+  extern __shared__ T share_mem_NMS[];
+  T *index_arr = share_mem_NMS;
+  S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2);
+  // fall back to global-memory buffers when shared memory cannot hold the arrays
+  if (index_buff != nullptr && data_buff != nullptr) {
+    index_arr = index_buff + blockIdx.x * ceil_power2;
+    data_arr = data_buff + blockIdx.x * ceil_power2;
+  }
+  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
+    index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max();
+    // populated directly from input data
+    data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max();
+  }
+  __syncthreads();
+  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
+    for (size_t j = (i >> 1); j > 0; j >>= 1) {
+      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
+        size_t tid_comp = tid ^ j;
+        if (tid_comp > tid) {
+          if ((tid & i) == 0) {
+            if (data_arr[tid] > data_arr[tid_comp]) {
+              SwapM(&index_arr[tid], &index_arr[tid_comp]);
+              SwapM(&data_arr[tid], &data_arr[tid_comp]);
+            }
+          } else {
+            if (data_arr[tid] < data_arr[tid_comp]) {
+              SwapM(&index_arr[tid], &index_arr[tid_comp]);
+              SwapM(&data_arr[tid], &data_arr[tid_comp]);
+            }
+          }
+        }
+      }
+      __syncthreads();
+    }
+  }
+  T correct_index;
+  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
+    correct_index = index_arr[(inner - 1) - tid];
+    // move data from input to output, correcting the ordering with the sorted index array
+    for (auto i : {0, 1, 2, 3, 4}) {
+      data_out[(blockIdx.x * inner + tid) * box_size_ + i] =
+        data_in[(blockIdx.x * inner + correct_index) * box_size_ + i];
+    }
+  }
+}
+
+template <typename T>
+void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) {
+  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
+}
+
+template <typename T, typename S>
+void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
+                       int box_size_, cudaStream_t stream) {
+  int ceil_power2 = RoundUpPower2M(inner);
+  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
+  if (share_mem > SHARED_MEM_PER_BLOCK) {
+    share_mem = 0;  // shared memory is insufficient; keep the global buffers
+  } else {
+    data_buff = nullptr;
+    index_buff = nullptr;
+  }
+  int thread = std::min(ceil_power2, GET_THREADS);
+  BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out,
+                                                                index_buff, data_buff, box_size_);
+}
+
+template <typename T>
+void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
+                    cudaStream_t cuda_stream) {
+  NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
+                                                                      box_size_);
+}
+
+template <typename T>
+void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
+                  cudaStream_t cuda_stream) {
+  FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
+}
+
+template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *output, int box_size_,
+                                   cudaStream_t cuda_stream);
+
+template void BitonicSortByKeyM<int, float>(const int &outer, const int &inner, float *data_in, float *data_out,
+                                            int *index_buff, float *data_buff, int box_size_, cudaStream_t stream);
+
+template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
+                                    int box_size_, cudaStream_t cuda_stream);
+
+template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
+                                  int box_size_, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
new file mode 100644
index 0000000000..0eafd51389
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
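+
+// The four launchers below form the NMSWithMask pipeline over one [N, 5] box
+// tensor (x1, y1, x2, y2, score): BitonicSortByKeyM orders the boxes by
+// descending score, CalPreprocess fills the selected-index array and per-box
+// areas, CalNMSWithMask masks each box against its CUDA block's highest-scoring
+// box, and CalFinalPass serially resolves the remaining pairwise overlaps.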
+
+template <typename T>
+void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
+                    cudaStream_t cuda_stream);
+
+template <typename T, typename S>
+void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
+                       int box_size_, cudaStream_t stream);
+
+template <typename T>
+void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
+                  cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc
new file mode 100644
index 0000000000..5def6b3af4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(NMSWithMask,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeBool),
+                      NMSWithMaskGpuFwdKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
new file mode 100644
index 0000000000..a5e0464cb9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
@@ -0,0 +1,121 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
+
+#include <vector>
+#include <string>
+#include <algorithm>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class NMSWithMaskGpuFwdKernel : public GpuKernel {
+ public:
+  NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {}
+  ~NMSWithMaskGpuFwdKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *data_buff = GetDeviceAddress<T>(workspace, 0);  // sort buffer
+    int *index_buff = GetDeviceAddress<int>(workspace, 1);
+    T *area = GetDeviceAddress<T>(workspace, 2);  // stores area values for all boxes
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    int *sel_idx = GetDeviceAddress<int>(outputs, 1);
+    bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);
+
+    BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    iou_value_ = GetAttr<float>(kernel_node, "iou_threshold");
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but NMSWithMask needs 1 input.";
+      return false;
+    }
+
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 3) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but NMSWithMask needs 3 outputs.";
+      return false;
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (CHECK_NULL_INPUT(input_shape)) {
+      MS_LOG(WARNING) << "NMSWithMask input is null";
+      InitSizeLists();
+      return true;
+    }
+
+    num_input_ = input_shape[0];  // N in the [N, 5] input
+
+    input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
+    output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
+    workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int));
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
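+  // Buffer layout, all sized from N = num_input_ boxes:
+  //   outputs   - sorted boxes (N * 5 * sizeof(T)), selected indices (N ints),
+  //               selection mask (N bools)
+  //   workspace - sort data buffer (N Ts), sort index buffer (N ints),
+  //               per-box areas (N Ts)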
+  void InitSizeLists() override {
+    // N sized input/output data
+    input_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
+    output_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
+    output_size_list_.push_back(num_input_ * sizeof(int));
+    output_size_list_.push_back(num_input_ * sizeof(bool));
+
+    // N sized workspace arrays
+    workspace_size_list_.push_back(num_input_ * sizeof(T));
+    workspace_size_list_.push_back(num_input_ * sizeof(int));
+    workspace_size_list_.push_back(num_input_ * sizeof(T));
+  }
+
+ private:
+  int num_input_;
+  float iou_value_;
+  static const int box_size_ = 5;  // values per box: x1, y1, x2, y2, score
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
diff --git a/tests/st/ops/gpu/test_nms_with_mask_op.py b/tests/st/ops/gpu/test_nms_with_mask_op.py
new file mode 100644
index 0000000000..210a14b3ff
--- /dev/null
+++ b/tests/st/ops/gpu/test_nms_with_mask_op.py
@@ -0,0 +1,154 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+def manualNMS(bbox, overlap_val_iou):
+    mask = [True] * len(bbox)
+    for box_a_index, _ in enumerate(bbox):
+        if not mask[box_a_index]:
+            continue  # box already suppressed
+        box_a = bbox[box_a_index]  # select box for value extraction
+        for box_b_index in range(box_a_index + 1, len(bbox)):
+            if not mask[box_b_index]:
+                continue  # box already suppressed
+            box_b = bbox[box_b_index]
+            areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
+            areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
+            overlap_x1 = max(box_a[0], box_b[0])
+            overlap_y1 = max(box_a[1], box_b[1])
+            overlap_x2 = min(box_a[2], box_b[2])
+            overlap_y2 = min(box_a[3], box_b[3])
+            width = max((overlap_x2 - overlap_x1), 0)
+            height = max((overlap_y2 - overlap_y1), 0)
+            # generate IOU decision
+            mask[box_b_index] = not (
+                (width * height) / (areaA + areaB - (width * height)) > overlap_val_iou)
+    return mask
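+
+
+# manualNMS is a host-side reference for the kernel's IOUDecision: rows must
+# already be in descending score order, and each surviving box suppresses any
+# later box whose IOU with it exceeds overlap_val_iou.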
+
+
+def runMSRun(op, bbox):
+    inputs = Tensor(bbox, mindspore.float32)
+    box, _, mask = op(inputs)
+    box = box.asnumpy()
+    mask = mask.asnumpy()
+    sel_idx = np.where(mask)
+    sel_rows = box[sel_idx][:, 0:4]
+    sel_score = box[sel_idx][:, -1]
+    return sel_rows, sel_score
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nms_with_mask_check_order():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    nms_op = P.NMSWithMask(0.5)
+    for _ in range(500):
+        count = 20
+        box = np.random.randint(1, 100, size=(count, 4))
+        box[:, 2] = box[:, 0] + box[:, 2]
+        box[:, 3] = box[:, 1] + box[:, 3]
+        unsorted_scores = np.random.rand(count, 1)
+        bbox = np.hstack((box, unsorted_scores))
+        bbox = Tensor(bbox, dtype=mindspore.float32)
+        prop, _, _ = nms_op(bbox)
+        ms_sorted_scores = (prop.asnumpy()[:, -1])  # select just the scores
+        np_sorted_scores = (np.sort(unsorted_scores, axis=0)[::-1][:, 0])  # sort manually
+        np.testing.assert_array_almost_equal(
+            ms_sorted_scores, np_sorted_scores)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nms_with_mask_check_result():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_count = 500
+    for x in range(1, test_count + 1):
+        count = 20  # size of bbox lists
+        nms_op = P.NMSWithMask(x * 0.002)  # covers the full threshold range between 0 and 1
+        box = np.random.randint(1, 100, size=(count, 4))
+        box[:, 2] = box[:, 0] + box[:, 2]
+        box[:, 3] = box[:, 1] + box[:, 3]
+        unsorted_scores = np.random.rand(count, 1)
+        sorted_scores = np.sort(unsorted_scores, axis=0)[::-1]
+        bbox = np.hstack((box, sorted_scores))
+        bbox = Tensor(bbox, dtype=mindspore.float32)
+        _, _, mask = nms_op(bbox)
+        mask = mask.asnumpy()
+        manual_mask = manualNMS(box, x * 0.002)
+        np.testing.assert_array_equal(mask, np.array(manual_mask))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nms_with_mask_edge_case_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    # CASE 1 - FULL OVERLAP BOXES - every box is duplicated with a different score
+    nms_op1 = P.NMSWithMask(0.3)
+    bbox1 = [[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1], [20, 10, 45, 26, 0.9], [15, 17, 35, 38, 0.5],
+             [10, 20, 30, 40, 0.4], [35, 35, 89, 90, 0.8], [12, 4, 33, 17, 0.3], [20, 11, 38, 23, 0.2],
+             [20, 10, 45, 26, 0.1], [15, 17, 35, 38, 0.8], [10, 20, 30, 40, 0.41], [35, 35, 89, 90, 0.82]]
+    expected_bbox = np.array([[20., 10., 45., 26.],
+                              [35., 35., 89., 90.],
+                              [15., 17., 35., 38.],
+                              [12., 4., 33., 17.]])
+    expected_score = np.array([0.9, 0.82, 0.8, 0.6])
+
+    sel_rows, sel_score = runMSRun(nms_op1, bbox1)
+    np.testing.assert_almost_equal(sel_rows, expected_bbox)
+    np.testing.assert_almost_equal(sel_score, expected_score)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nms_with_mask_edge_case_2():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    # CASE 2 - zero-area boxes with valid scores
+    nms_op2 = P.NMSWithMask(0.5)
+    bbox2 = [[0, 0, 0, 0, 0.6], [0, 0, 0, 0, 0.1]]
+    expected_bbox = np.array([[0., 0., 0., 0.],
+                              [0., 0., 0., 0.]])
+    expected_score = np.array([0.6, 0.1])
+
+    sel_rows, sel_score = runMSRun(nms_op2, bbox2)
+    np.testing.assert_almost_equal(sel_rows, expected_bbox)
+    np.testing.assert_almost_equal(sel_score, expected_score)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nms_with_mask_edge_case_3():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    # CASE 3 - x2/x1 and y2/y1 coordinates out of order
+    nms_op3 = P.NMSWithMask(0.7)
+    bbox3 = [[70, 70, 45, 75, 0.6], [30, 33, 43, 29, 0.1]]
+    expected_bbox = np.array([[70., 70., 45., 75.],
+                              [30., 33., 43., 29.]])
+    expected_score = np.array([0.6, 0.1])
+
+    sel_rows, sel_score = runMSRun(nms_op3, bbox3)
+    np.testing.assert_almost_equal(sel_rows, expected_bbox)
+    np.testing.assert_almost_equal(sel_score, expected_score)