implement partial sort

pull/7281/head
wangzhe 4 years ago
parent de989d369d
commit 52e9ba1df7

@ -19,18 +19,84 @@
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"
int ScoreWithIndexCmp(const void *a, const void *b) {
ScoreWithIndex *pa = (ScoreWithIndex *)a;
ScoreWithIndex *pb = (ScoreWithIndex *)b;
bool ScoreWithIndexCmp(ScoreWithIndex *pa, ScoreWithIndex *pb) {
if (pa->score > pb->score) {
return -1;
return true;
} else if (pa->score < pb->score) {
return 1;
return false;
} else {
return pa->index - pb->index;
return pa->index < pb->index;
}
}
void PushHeap(ScoreWithIndex *root, int cur, int top_index, ScoreWithIndex value) {
int parent = (cur - 1) / 2;
while (cur > top_index && ScoreWithIndexCmp(root + parent, &value)) {
*(root + cur) = root[parent];
cur = parent;
parent = (cur - 1) / 2;
}
*(root + cur) = value;
}
void AdjustHeap(ScoreWithIndex *root, int cur, int limit, ScoreWithIndex value) {
int top_index = cur;
int second_child = cur;
while (second_child < (limit - 1) / 2) {
second_child = 2 * (second_child + 1);
if (ScoreWithIndexCmp(root + second_child, root + second_child - 1)) {
second_child--;
}
*(root + cur) = *(root + second_child);
cur = second_child;
}
if ((limit & 1) == 0 && second_child == (limit - 2) / 2) {
second_child = 2 * (second_child + 1);
*(root + cur) = *(root + second_child - 1);
cur = second_child - 1;
}
PushHeap(root, cur, top_index, value);
}
void PopHeap(ScoreWithIndex *root, int limit, ScoreWithIndex *result) {
ScoreWithIndex value = *result;
*result = *root;
AdjustHeap(root, 0, limit, value);
}
void MakeHeap(ScoreWithIndex *values, int limit) {
if (limit < 2) return;
int parent = (limit - 2) / 2;
while (true) {
AdjustHeap(values, parent, limit, values[parent]);
if (parent == 0) {
return;
}
parent--;
}
}
void SortHeap(ScoreWithIndex *root, int limit) {
while (limit > 1) {
--limit;
PopHeap(root, limit, root + limit);
}
}
void HeapSelect(ScoreWithIndex *root, int cur, int limit) {
MakeHeap(root, cur);
for (int i = cur; i < limit; ++i) {
if (ScoreWithIndexCmp(root + i, root)) {
PopHeap(root, cur, root + i);
}
}
}
void PartialSort(ScoreWithIndex *values, int num_to_sort, int num_values) {
HeapSelect(values, num_to_sort, num_values);
SortHeap(values, num_to_sort);
}
float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
@ -70,7 +136,7 @@ int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const in
const int output_num = candidate_num < max_detections ? candidate_num : max_detections;
int possible_candidate_num = candidate_num;
int selected_num = 0;
qsort(score_with_index, candidate_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp);
PartialSort(score_with_index, candidate_num, candidate_num);
for (int i = 0; i < candidate_num; ++i) {
nms_candidate[i] = 1;
}
@ -134,7 +200,7 @@ int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, c
}
all_classes_output_num =
all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_;
qsort(score_with_index_all, all_classes_sorted_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp);
PartialSort(score_with_index_all, all_classes_output_num, all_classes_sorted_num);
for (int i = 0; i < all_classes_output_num; ++i) {
score_with_index_all[i].index = indexes[score_with_index_all[i].index];
}
@ -178,8 +244,7 @@ int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, cons
// save box and class info to index
score_with_class_all[i * param->num_classes_ + j - first_class_index].index = i * num_classes_with_bg + j;
}
qsort(score_with_class_all + i * param->num_classes_, param->num_classes_, sizeof(ScoreWithIndex),
ScoreWithIndexCmp);
PartialSort(score_with_class_all + i * param->num_classes_, max_classes_per_anchor, param->num_classes_);
const float score_max = (score_with_class_all + i * param->num_classes_)->score;
if (score_max >= param->nms_score_threshold_) {
score_with_class[candidate_num].index = i;

@ -43,10 +43,6 @@ typedef struct {
extern "C" {
#endif
void nms_multi_classes_regular();
void nms_multi_classes_fase();
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores,
float *input_anchors, float *output_boxes, float *output_classes, float *output_scores,
float *output_num, DetectionPostProcessParameter *param);

@ -30,17 +30,12 @@ class DetectionPostProcessCPUKernel : public LiteKernel {
DetectionPostProcessCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<DetectionPostProcessCPUKernel *>(parameter);
}
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~DetectionPostProcessCPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
private:
DetectionPostProcessCPUKernel *param_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_

Loading…
Cancel
Save