diff --git a/mindspore/lite/nnacl/fp32/detection_post_process.c b/mindspore/lite/nnacl/fp32/detection_post_process.c index f7c6953e2c..da54d3f998 100644 --- a/mindspore/lite/nnacl/fp32/detection_post_process.c +++ b/mindspore/lite/nnacl/fp32/detection_post_process.c @@ -19,18 +19,84 @@ #include "nnacl/errorcode.h" #include "nnacl/op_base.h" -int ScoreWithIndexCmp(const void *a, const void *b) { - ScoreWithIndex *pa = (ScoreWithIndex *)a; - ScoreWithIndex *pb = (ScoreWithIndex *)b; +bool ScoreWithIndexCmp(ScoreWithIndex *pa, ScoreWithIndex *pb) { if (pa->score > pb->score) { - return -1; + return true; } else if (pa->score < pb->score) { - return 1; + return false; } else { - return pa->index - pb->index; + return pa->index < pb->index; } } +void PushHeap(ScoreWithIndex *root, int cur, int top_index, ScoreWithIndex value) { + int parent = (cur - 1) / 2; + while (cur > top_index && ScoreWithIndexCmp(root + parent, &value)) { + *(root + cur) = root[parent]; + cur = parent; + parent = (cur - 1) / 2; + } + *(root + cur) = value; +} + +void AdjustHeap(ScoreWithIndex *root, int cur, int limit, ScoreWithIndex value) { + int top_index = cur; + int second_child = cur; + while (second_child < (limit - 1) / 2) { + second_child = 2 * (second_child + 1); + if (ScoreWithIndexCmp(root + second_child, root + second_child - 1)) { + second_child--; + } + *(root + cur) = *(root + second_child); + cur = second_child; + } + if ((limit & 1) == 0 && second_child == (limit - 2) / 2) { + second_child = 2 * (second_child + 1); + *(root + cur) = *(root + second_child - 1); + cur = second_child - 1; + } + PushHeap(root, cur, top_index, value); +} + +void PopHeap(ScoreWithIndex *root, int limit, ScoreWithIndex *result) { + ScoreWithIndex value = *result; + *result = *root; + AdjustHeap(root, 0, limit, value); +} + +void MakeHeap(ScoreWithIndex *values, int limit) { + if (limit < 2) return; + int parent = (limit - 2) / 2; + while (true) { + AdjustHeap(values, parent, limit, values[parent]); + if (parent == 0) { + return; + } + parent--; + } +} + +void SortHeap(ScoreWithIndex *root, int limit) { + while (limit > 1) { + --limit; + PopHeap(root, limit, root + limit); + } +} + +void HeapSelect(ScoreWithIndex *root, int cur, int limit) { + MakeHeap(root, cur); + for (int i = cur; i < limit; ++i) { + if (ScoreWithIndexCmp(root + i, root)) { + PopHeap(root, cur, root + i); + } + } +} + +void PartialSort(ScoreWithIndex *values, int num_to_sort, int num_values) { + HeapSelect(values, num_to_sort, num_values); + SortHeap(values, num_to_sort); +} + float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) { const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin); const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin); @@ -70,7 +136,7 @@ int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const in const int output_num = candidate_num < max_detections ? candidate_num : max_detections; int possible_candidate_num = candidate_num; int selected_num = 0; - qsort(score_with_index, candidate_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp); + PartialSort(score_with_index, candidate_num, candidate_num); for (int i = 0; i < candidate_num; ++i) { nms_candidate[i] = 1; } @@ -134,7 +200,7 @@ int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, c } all_classes_output_num = all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_; - qsort(score_with_index_all, all_classes_sorted_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp); + PartialSort(score_with_index_all, all_classes_output_num, all_classes_sorted_num); for (int i = 0; i < all_classes_output_num; ++i) { score_with_index_all[i].index = indexes[score_with_index_all[i].index]; } @@ -178,8 +244,7 @@ int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, cons // save box and class info to index score_with_class_all[i * param->num_classes_ + j - first_class_index].index = i * num_classes_with_bg + j; } - qsort(score_with_class_all + i * param->num_classes_, param->num_classes_, sizeof(ScoreWithIndex), - ScoreWithIndexCmp); + PartialSort(score_with_class_all + i * param->num_classes_, max_classes_per_anchor, param->num_classes_); const float score_max = (score_with_class_all + i * param->num_classes_)->score; if (score_max >= param->nms_score_threshold_) { score_with_class[candidate_num].index = i; diff --git a/mindspore/lite/nnacl/fp32/detection_post_process.h b/mindspore/lite/nnacl/fp32/detection_post_process.h index 208efa228a..3bfa566303 100644 --- a/mindspore/lite/nnacl/fp32/detection_post_process.h +++ b/mindspore/lite/nnacl/fp32/detection_post_process.h @@ -43,10 +43,6 @@ typedef struct { extern "C" { #endif -void nms_multi_classes_regular(); - -void nms_multi_classes_fase(); - int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores, float *input_anchors, float *output_boxes, float *output_classes, float *output_scores, float *output_num, DetectionPostProcessParameter *param); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h b/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h index c51befff44..6e78aec804 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h @@ -30,17 +30,12 @@ class DetectionPostProcessCPUKernel : public LiteKernel { DetectionPostProcessCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { - param_ = reinterpret_cast(parameter); - } + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~DetectionPostProcessCPUKernel() override; int Init() override; int ReSize() override; int Run() override; - - private: - DetectionPostProcessCPUKernel *param_; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_