|
|
|
@ -69,6 +69,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
float overlap_threshold = ctx.Attr<float>("overlap_threshold");
|
|
|
|
|
float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
|
|
|
|
|
auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type"));
|
|
|
|
|
int class_num = ctx.Attr<int>("class_num");
|
|
|
|
|
|
|
|
|
|
auto label_lod = in_label->lod();
|
|
|
|
|
auto detect_lod = in_detect->lod();
|
|
|
|
@ -95,17 +96,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
if (in_pos_count != nullptr && state) {
|
|
|
|
|
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
|
|
|
|
|
true_pos, false_pos);
|
|
|
|
|
true_pos, false_pos, class_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult,
|
|
|
|
|
overlap_threshold, label_pos_count, true_pos,
|
|
|
|
|
false_pos);
|
|
|
|
|
|
|
|
|
|
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos);
|
|
|
|
|
int background_label = ctx.Attr<int>("background_label");
|
|
|
|
|
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos,
|
|
|
|
|
background_label);
|
|
|
|
|
|
|
|
|
|
GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
|
|
|
|
|
*out_true_pos, *out_false_pos);
|
|
|
|
|
*out_true_pos, *out_false_pos, class_num);
|
|
|
|
|
|
|
|
|
|
T* map_data = out_map->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
map_data[0] = map;
|
|
|
|
@ -190,24 +193,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
|
|
|
|
|
framework::Tensor& output_pos_count,
|
|
|
|
|
framework::LoDTensor& output_true_pos,
|
|
|
|
|
framework::LoDTensor& output_false_pos) const {
|
|
|
|
|
int max_class_id = 0;
|
|
|
|
|
framework::LoDTensor& output_false_pos, const int class_num) const {
|
|
|
|
|
int true_pos_count = 0;
|
|
|
|
|
int false_pos_count = 0;
|
|
|
|
|
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
|
|
|
|
|
int label = it->first;
|
|
|
|
|
if (label > max_class_id) max_class_id = label;
|
|
|
|
|
int label_num_pos = it->second;
|
|
|
|
|
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
|
|
|
|
|
continue;
|
|
|
|
|
auto label_true_pos = true_pos.find(label)->second;
|
|
|
|
|
auto label_false_pos = false_pos.find(label)->second;
|
|
|
|
|
true_pos_count += label_true_pos.size();
|
|
|
|
|
false_pos_count += label_false_pos.size();
|
|
|
|
|
for (auto it = true_pos.begin(); it != true_pos.end(); ++it) {
|
|
|
|
|
auto tp = it->second;
|
|
|
|
|
true_pos_count += tp.size();
|
|
|
|
|
}
|
|
|
|
|
for (auto it = false_pos.begin(); it != false_pos.end(); ++it) {
|
|
|
|
|
auto fp = it->second;
|
|
|
|
|
false_pos_count += fp.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int* pos_count_data = output_pos_count.mutable_data<int>(
|
|
|
|
|
framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
|
|
|
|
|
framework::make_ddim({class_num, 1}), ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
T* true_pos_data = output_true_pos.mutable_data<T>(
|
|
|
|
|
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
|
|
|
|
@ -217,7 +216,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
false_pos_count = 0;
|
|
|
|
|
std::vector<size_t> true_pos_starts = {0};
|
|
|
|
|
std::vector<size_t> false_pos_starts = {0};
|
|
|
|
|
for (int i = 0; i <= max_class_id; ++i) {
|
|
|
|
|
for (int i = 0; i < class_num; ++i) {
|
|
|
|
|
auto it_count = label_pos_count.find(i);
|
|
|
|
|
pos_count_data[i] = 0;
|
|
|
|
|
if (it_count != label_pos_count.end()) {
|
|
|
|
@ -258,17 +257,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetInputPos(
|
|
|
|
|
const framework::Tensor& input_pos_count,
|
|
|
|
|
const framework::LoDTensor& input_true_pos,
|
|
|
|
|
const framework::LoDTensor& input_false_pos,
|
|
|
|
|
std::map<int, int>& label_pos_count,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
|
|
|
|
|
void GetInputPos(const framework::Tensor& input_pos_count,
|
|
|
|
|
const framework::LoDTensor& input_true_pos,
|
|
|
|
|
const framework::LoDTensor& input_false_pos,
|
|
|
|
|
std::map<int, int>& label_pos_count,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
std::map<int, std::vector<std::pair<T, int>>>& false_pos,
|
|
|
|
|
const int class_num) const {
|
|
|
|
|
constexpr T kEPS = static_cast<T>(1e-6);
|
|
|
|
|
int class_number = input_pos_count.dims()[0];
|
|
|
|
|
const int* pos_count_data = input_pos_count.data<int>();
|
|
|
|
|
for (int i = 0; i < class_number; ++i) {
|
|
|
|
|
for (int i = 0; i < class_num; ++i) {
|
|
|
|
|
label_pos_count[i] = pos_count_data[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -391,17 +389,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T CalcMAP(
|
|
|
|
|
APType ap_type, const std::map<int, int>& label_pos_count,
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
|
|
|
|
|
T CalcMAP(APType ap_type, const std::map<int, int>& label_pos_count,
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
|
|
|
|
|
const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
|
|
|
|
|
const int background_label) const {
|
|
|
|
|
T mAP = 0.0;
|
|
|
|
|
int count = 0;
|
|
|
|
|
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
|
|
|
|
|
int label = it->first;
|
|
|
|
|
int label_num_pos = it->second;
|
|
|
|
|
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
|
|
|
|
|
if (label_num_pos == background_label ||
|
|
|
|
|
true_pos.find(label) == true_pos.end()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto label_true_pos = true_pos.find(label)->second;
|
|
|
|
|
auto label_false_pos = false_pos.find(label)->second;
|
|
|
|
|
// Compute average precision.
|
|
|
|
@ -450,7 +450,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (count != 0) mAP /= count;
|
|
|
|
|
return mAP * 100;
|
|
|
|
|
return mAP;
|
|
|
|
|
}
|
|
|
|
|
}; // namespace operators
|
|
|
|
|
|
|
|
|
|