|
|
|
@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
@ -82,7 +87,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<std::map<int, std::vector<Box>>> gt_boxes;
|
|
|
|
|
std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes;
|
|
|
|
|
|
|
|
|
|
GetBoxes(*in_label, *in_detect, gt_boxes, detect_boxes);
|
|
|
|
|
GetBoxes(*in_label, *in_detect, >_boxes, detect_boxes);
|
|
|
|
|
|
|
|
|
|
std::map<int, int> label_pos_count;
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>> true_pos;
|
|
|
|
@ -95,20 +100,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (in_pos_count != nullptr && state) {
|
|
|
|
|
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
|
|
|
|
|
true_pos, false_pos, class_num);
|
|
|
|
|
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, &label_pos_count,
|
|
|
|
|
&true_pos, &false_pos, class_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult,
|
|
|
|
|
overlap_threshold, label_pos_count, true_pos,
|
|
|
|
|
false_pos);
|
|
|
|
|
overlap_threshold, &label_pos_count, &true_pos,
|
|
|
|
|
&false_pos);
|
|
|
|
|
|
|
|
|
|
int background_label = ctx.Attr<int>("background_label");
|
|
|
|
|
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos,
|
|
|
|
|
background_label);
|
|
|
|
|
|
|
|
|
|
GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
|
|
|
|
|
*out_true_pos, *out_false_pos, class_num);
|
|
|
|
|
GetOutputPos(ctx, label_pos_count, true_pos, false_pos, out_pos_count,
|
|
|
|
|
out_true_pos, out_false_pos, class_num);
|
|
|
|
|
|
|
|
|
|
T* map_data = out_map->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
map_data[0] = map;
|
|
|
|
@ -155,7 +160,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
void GetBoxes(const framework::LoDTensor& input_label,
|
|
|
|
|
const framework::LoDTensor& input_detect,
|
|
|
|
|
std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
|
|
|
|
|
std::vector<std::map<int, std::vector<Box>>>* gt_boxes,
|
|
|
|
|
std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
|
|
|
|
|
detect_boxes) const {
|
|
|
|
|
auto labels = framework::EigenTensor<T, 2>::From(input_label);
|
|
|
|
@ -179,7 +184,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
box.is_difficult = true;
|
|
|
|
|
boxes[label].push_back(box);
|
|
|
|
|
}
|
|
|
|
|
gt_boxes.push_back(boxes);
|
|
|
|
|
gt_boxes->push_back(boxes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto detect_index = detect_lod[0];
|
|
|
|
@ -200,9 +205,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const std::map<int, int>& label_pos_count,
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
|
|
|
|
|
framework::Tensor& output_pos_count,
|
|
|
|
|
framework::LoDTensor& output_true_pos,
|
|
|
|
|
framework::LoDTensor& output_false_pos, const int class_num) const {
|
|
|
|
|
framework::Tensor* output_pos_count,
|
|
|
|
|
framework::LoDTensor* output_true_pos,
|
|
|
|
|
framework::LoDTensor* output_false_pos, const int class_num) const {
|
|
|
|
|
int true_pos_count = 0;
|
|
|
|
|
int false_pos_count = 0;
|
|
|
|
|
for (auto it = true_pos.begin(); it != true_pos.end(); ++it) {
|
|
|
|
@ -214,12 +219,12 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
false_pos_count += fp.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int* pos_count_data = output_pos_count.mutable_data<int>(
|
|
|
|
|
int* pos_count_data = output_pos_count->mutable_data<int>(
|
|
|
|
|
framework::make_ddim({class_num, 1}), ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
T* true_pos_data = output_true_pos.mutable_data<T>(
|
|
|
|
|
T* true_pos_data = output_true_pos->mutable_data<T>(
|
|
|
|
|
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
|
|
|
|
|
T* false_pos_data = output_false_pos.mutable_data<T>(
|
|
|
|
|
T* false_pos_data = output_false_pos->mutable_data<T>(
|
|
|
|
|
framework::make_ddim({false_pos_count, 2}), ctx.GetPlace());
|
|
|
|
|
true_pos_count = 0;
|
|
|
|
|
false_pos_count = 0;
|
|
|
|
@ -261,21 +266,21 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
framework::LoD false_pos_lod;
|
|
|
|
|
false_pos_lod.emplace_back(false_pos_starts);
|
|
|
|
|
|
|
|
|
|
output_true_pos.set_lod(true_pos_lod);
|
|
|
|
|
output_false_pos.set_lod(false_pos_lod);
|
|
|
|
|
output_true_pos->set_lod(true_pos_lod);
|
|
|
|
|
output_false_pos->set_lod(false_pos_lod);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetInputPos(const framework::Tensor& input_pos_count,
|
|
|
|
|
const framework::LoDTensor& input_true_pos,
|
|
|
|
|
const framework::LoDTensor& input_false_pos,
|
|
|
|
|
std::map<int, int>& label_pos_count,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& false_pos,
|
|
|
|
|
std::map<int, int>* label_pos_count,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>* true_pos,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>* false_pos,
|
|
|
|
|
const int class_num) const {
|
|
|
|
|
const int* pos_count_data = input_pos_count.data<int>();
|
|
|
|
|
for (int i = 0; i < class_num; ++i) {
|
|
|
|
|
label_pos_count[i] = pos_count_data[i];
|
|
|
|
|
(*label_pos_count)[i] = pos_count_data[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto SetData = [](const framework::LoDTensor& pos_tensor,
|
|
|
|
@ -291,8 +296,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
SetData(input_true_pos, true_pos);
|
|
|
|
|
SetData(input_false_pos, false_pos);
|
|
|
|
|
SetData(input_true_pos, *true_pos);
|
|
|
|
|
SetData(input_false_pos, *false_pos);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -301,9 +306,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
|
|
|
|
|
detect_boxes,
|
|
|
|
|
bool evaluate_difficult, float overlap_threshold,
|
|
|
|
|
std::map<int, int>& label_pos_count,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
|
|
|
|
|
std::map<int, int>* label_pos_count,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>* true_pos,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>* false_pos) const {
|
|
|
|
|
int batch_size = gt_boxes.size();
|
|
|
|
|
for (int n = 0; n < batch_size; ++n) {
|
|
|
|
|
auto image_gt_boxes = gt_boxes[n];
|
|
|
|
@ -320,10 +325,10 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
int label = it->first;
|
|
|
|
|
if (label_pos_count.find(label) == label_pos_count.end()) {
|
|
|
|
|
label_pos_count[label] = count;
|
|
|
|
|
if (label_pos_count->find(label) == label_pos_count->end()) {
|
|
|
|
|
(*label_pos_count)[label] = count;
|
|
|
|
|
} else {
|
|
|
|
|
label_pos_count[label] += count;
|
|
|
|
|
(*label_pos_count)[label] += count;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -338,8 +343,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
int label = it->first;
|
|
|
|
|
for (size_t i = 0; i < pred_boxes.size(); ++i) {
|
|
|
|
|
auto score = pred_boxes[i].first;
|
|
|
|
|
true_pos[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
false_pos[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
(*true_pos)[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
(*false_pos)[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
@ -351,8 +356,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (image_gt_boxes.find(label) == image_gt_boxes.end()) {
|
|
|
|
|
for (size_t i = 0; i < pred_boxes.size(); ++i) {
|
|
|
|
|
auto score = pred_boxes[i].first;
|
|
|
|
|
true_pos[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
false_pos[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
(*true_pos)[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
(*false_pos)[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -381,17 +386,17 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
(!evaluate_difficult && !matched_bboxes[max_idx].is_difficult);
|
|
|
|
|
if (match_evaluate_difficult) {
|
|
|
|
|
if (!visited[max_idx]) {
|
|
|
|
|
true_pos[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
false_pos[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
(*true_pos)[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
(*false_pos)[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
visited[max_idx] = true;
|
|
|
|
|
} else {
|
|
|
|
|
true_pos[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
false_pos[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
(*true_pos)[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
(*false_pos)[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
true_pos[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
false_pos[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
(*true_pos)[label].push_back(std::make_pair(score, 0));
|
|
|
|
|
(*false_pos)[label].push_back(std::make_pair(score, 1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|