parent
e72b865cb1
commit
67cbb3e3b6
@ -0,0 +1,77 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/detection_map_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
class DetectionMAPOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
auto map_dim = framework::make_ddim({1});
|
||||
ctx->SetOutputDim("MAP", map_dim);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::Tensor>("Label")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
// Describes the inputs, output, attributes and user-facing documentation of
// the "detection_map" operator.
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  DetectionMAPOpMaker(framework::OpProto* proto,
                      framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Tensors consumed and produced by the operator.
    AddInput("Detect", "The detection output.");
    AddInput("Label", "The label data.");
    AddOutput("MAP", "The MAP evaluate result of the detection.");

    // Evaluation settings; each one has a default so callers may omit it.
    AddAttr<float>("overlap_threshold", "The overlap threshold.")
        .SetDefault(0.3f);
    AddAttr<bool>("evaluate_difficult",
                  "Switch to control whether the difficult data is evaluated.")
        .SetDefault(true);
    AddAttr<std::string>("ap_type",
                         "The AP algorithm type, 'Integral' or '11point'.")
        .SetDefault("Integral");

    AddComment(R"DOC(
Detection MAP Operator.

Detection MAP evaluator for SSD(Single Shot MultiBox Detector) algorithm.
Please get more information from the following papers:
https://arxiv.org/abs/1512.02325.

)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp,
|
||||
ops::DetectionMAPOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
detection_map, ops::DetectionMAPOpKernel<paddle::platform::GPUPlace, float>,
|
||||
ops::DetectionMAPOpKernel<paddle::platform::GPUPlace, double>);
|
@ -0,0 +1,20 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/detection_map_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
detection_map, ops::DetectionMAPOpKernel<paddle::platform::GPUPlace, float>,
|
||||
ops::DetectionMAPOpKernel<paddle::platform::GPUPlace, double>);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,22 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/math/detection_util.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,23 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/math/detection_util.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
#include "paddle/platform/cuda_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,128 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
#pragma once
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T>
|
||||
// Axis-aligned bounding box described by its min/max corners, plus a flag
// marking "difficult" ground-truth objects.
template <typename T>
struct BBox {
  // Builds a box from its corner coordinates; the box is not difficult.
  BBox(T x_min, T y_min, T x_max, T y_max)
      : x_min(x_min),
        y_min(y_min),
        x_max(x_max),
        y_max(y_max),
        is_difficult(false) {}

  // Builds an empty box at the origin. Every member is initialized so that a
  // default-constructed box can be copied and read safely even when a caller
  // only fills in the coordinates afterwards.
  BBox() : x_min(), y_min(), x_max(), y_max(), is_difficult(false) {}

  // Horizontal extent (x_max - x_min).
  T get_width() const { return x_max - x_min; }

  // Vertical extent (y_max - y_min).
  T get_height() const { return y_max - y_min; }

  // Midpoint of the box along the x axis.
  T get_center_x() const { return (x_min + x_max) / 2; }

  // Midpoint of the box along the y axis.
  T get_center_y() const { return (y_min + y_max) / 2; }

  // Width times height; no clamping is applied to inverted boxes.
  T get_area() const { return get_width() * get_height(); }

  // coordinate of bounding box
  T x_min;
  T y_min;
  T x_max;
  T y_max;
  // whether difficult object (e.g. object with heavy occlusion is difficult)
  bool is_difficult;
};
|
||||
|
||||
template <typename T>
|
||||
void GetBBoxFromDetectData(const T* detect_data, const size_t num_bboxes,
|
||||
std::vector<T>& labels, std::vector<T>& scores,
|
||||
std::vector<BBox<T>>& bboxes) {
|
||||
size_t out_offset = bboxes.size();
|
||||
labels.resize(out_offset + num_bboxes);
|
||||
scores.resize(out_offset + num_bboxes);
|
||||
bboxes.resize(out_offset + num_bboxes);
|
||||
for (size_t i = 0; i < num_bboxes; ++i) {
|
||||
labels[out_offset + i] = *(detect_data + i * 7 + 1);
|
||||
scores[out_offset + i] = *(detect_data + i * 7 + 2);
|
||||
BBox<T> bbox;
|
||||
bbox.x_min = *(detect_data + i * 7 + 3);
|
||||
bbox.y_min = *(detect_data + i * 7 + 4);
|
||||
bbox.x_max = *(detect_data + i * 7 + 5);
|
||||
bbox.y_max = *(detect_data + i * 7 + 6);
|
||||
bboxes[out_offset + i] = bbox;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GetBBoxFromLabelData(const T* label_data, const size_t num_bboxes,
|
||||
std::vector<BBox<T>>& bboxes) {
|
||||
size_t out_offset = bboxes.size();
|
||||
bboxes.resize(bboxes.size() + num_bboxes);
|
||||
for (size_t i = 0; i < num_bboxes; ++i) {
|
||||
BBox<T> bbox;
|
||||
bbox.x_min = *(label_data + i * 6 + 1);
|
||||
bbox.y_min = *(label_data + i * 6 + 2);
|
||||
bbox.x_max = *(label_data + i * 6 + 3);
|
||||
bbox.y_max = *(label_data + i * 6 + 4);
|
||||
T is_difficult = *(label_data + i * 6 + 5);
|
||||
if (std::abs(is_difficult - 0.0) < 1e-6)
|
||||
bbox.is_difficult = false;
|
||||
else
|
||||
bbox.is_difficult = true;
|
||||
bboxes[out_offset + i] = bbox;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline float JaccardOverlap(const BBox<T>& bbox1, const BBox<T>& bbox2) {
|
||||
if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min ||
|
||||
bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) {
|
||||
return 0.0;
|
||||
} else {
|
||||
float inter_x_min = std::max(bbox1.x_min, bbox2.x_min);
|
||||
float inter_y_min = std::max(bbox1.y_min, bbox2.y_min);
|
||||
float inter_x_max = std::min(bbox1.x_max, bbox2.x_max);
|
||||
float inter_y_max = std::min(bbox1.y_max, bbox2.y_max);
|
||||
|
||||
float inter_width = inter_x_max - inter_x_min;
|
||||
float inter_height = inter_y_max - inter_y_min;
|
||||
float inter_area = inter_width * inter_height;
|
||||
|
||||
float bbox_area1 = bbox1.get_area();
|
||||
float bbox_area2 = bbox2.get_area();
|
||||
|
||||
return inter_area / (bbox_area1 + bbox_area2 - inter_area);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
// Comparator for sorting (score, payload) pairs by descending score.
// Only the score (pair.first) takes part in the comparison.
template <typename T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
                          const std::pair<float, T>& pair2) {
  const float lhs_score = pair1.first;
  const float rhs_score = pair2.first;
  return rhs_score < lhs_score;
}
|
||||
|
||||
// template <>
|
||||
// bool SortScorePairDescend(const std::pair<float, NormalizedBBox>& pair1,
|
||||
// const std::pair<float, NormalizedBBox>& pair2) {
|
||||
// return pair1.first > pair2.first;
|
||||
// }
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,155 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import collections
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestDetectionMAPOp(OpTest):
    """Checks the detection_map operator against a Python reference mAP.

    The fixture (labels, detections and the hand-written true/false positive
    table) comes from init_test_case(); calc_map() turns that table into the
    expected "MAP" output.
    """

    def set_data(self):
        """Builds self.inputs / self.attrs / self.outputs for OpTest."""
        self.init_test_case()

        # The reference value is computed from the raw Python lists, before
        # they are converted to float32 arrays below.
        self.mAP = [self.calc_map(self.tf_pos)]
        self.label = np.array(self.label).astype('float32')
        self.detect = np.array(self.detect).astype('float32')
        self.mAP = np.array(self.mAP).astype('float32')

        self.inputs = {
            'Label': (self.label, self.label_lod),
            'Detect': self.detect
        }

        self.attrs = {
            'overlap_threshold': self.overlap_threshold,
            'evaluate_difficult': self.evaluate_difficult,
            'ap_type': self.ap_type
        }

        self.outputs = {'MAP': self.mAP}

    def init_test_case(self):
        """Defines the attributes and the two-image fixture for this case."""
        self.overlap_threshold = 0.3
        self.evaluate_difficult = True
        self.ap_type = "Integral"

        self.label_lod = [[0, 2, 4]]
        # label xmin ymin xmax ymax difficult
        self.label = [[1, 0.1, 0.1, 0.3, 0.3, 0], [1, 0.6, 0.6, 0.8, 0.8, 1],
                      [2, 0.3, 0.3, 0.6, 0.5, 0], [1, 0.7, 0.1, 0.9, 0.3, 0]]

        # image_id label score xmin ymin xmax ymax
        self.detect = [
            [0, 1, 0.3, 0.1, 0.0, 0.4, 0.3], [0, 1, 0.7, 0.0, 0.1, 0.2, 0.3],
            [0, 1, 0.9, 0.7, 0.6, 0.8, 0.8], [1, 2, 0.8, 0.2, 0.1, 0.4, 0.4],
            [1, 2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 1, 0.2, 0.8, 0.1, 1.0, 0.3],
            [1, 3, 0.2, 0.8, 0.1, 1.0, 0.3]
        ]

        # image_id label score true_pos false_pos
        # [-1, 1, 3, -1, -1],
        # [-1, 2, 1, -1, -1]
        self.tf_pos = [[0, 1, 0.9, 1, 0], [0, 1, 0.7, 1, 0], [0, 1, 0.3, 0, 1],
                       [1, 1, 0.2, 1, 0], [1, 2, 0.8, 0, 1], [1, 2, 0.1, 1, 0],
                       [1, 3, 0.2, 0, 1]]

    def calc_map(self, tf_pos):
        """Returns the reference mAP (in percent) for the current fixture.

        Args:
            tf_pos: rows of [image_id, label, score, true_pos, false_pos].

        Returns:
            The mean, over every label that has both ground-truth boxes and
            detections, of the per-label average precision computed with
            self.ap_type ("Integral" or "11point"), multiplied by 100.
        """
        mAP = 0.0
        count = 0

        def get_accumulation(pos_list):
            # Running total of the flags, visited in descending score order.
            sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True)
            running_total = 0
            accu_list = []
            for (_, flag) in sorted_list:
                running_total += flag
                accu_list.append(running_total)
            return accu_list

        # Number of ground-truth boxes per label; difficult boxes are only
        # counted when evaluate_difficult is set.
        label_count = collections.Counter()
        for (label, xmin, ymin, xmax, ymax, difficult) in self.label:
            if self.evaluate_difficult:
                label_count[label] += 1
            elif not difficult:
                label_count[label] += 1

        true_pos = collections.defaultdict(list)
        false_pos = collections.defaultdict(list)
        for (image_id, label, score, tp, fp) in tf_pos:
            true_pos[label].append([score, tp])
            false_pos[label].append([score, fp])

        for (label, label_pos_num) in label_count.items():
            if label_pos_num == 0 or label not in true_pos:
                continue

            label_true_pos = true_pos[label]
            label_false_pos = false_pos[label]

            accu_tp_sum = get_accumulation(label_true_pos)
            accu_fp_sum = get_accumulation(label_false_pos)

            precision = []
            recall = []

            for i in range(len(accu_tp_sum)):
                precision.append(
                    float(accu_tp_sum[i]) /
                    float(accu_tp_sum[i] + accu_fp_sum[i]))
                recall.append(float(accu_tp_sum[i]) / label_pos_num)

            if self.ap_type == "11point":
                # One slot per recall level 0.0, 0.1, ..., 1.0.
                max_precisions = [0.0] * 11
                start_idx = len(accu_tp_sum) - 1
                # Walk the recall levels from 1.0 down to 0.0, and the
                # detections from lowest to highest score, including index 0.
                for j in range(10, -1, -1):
                    for i in range(start_idx, -1, -1):
                        if recall[i] < j / 10.0:
                            start_idx = i
                            if j > 0:
                                # The best precision found so far also
                                # applies to the next lower recall level.
                                max_precisions[j - 1] = max_precisions[j]
                            break
                        else:
                            if max_precisions[j] < precision[i]:
                                max_precisions[j] = precision[i]
                for j in range(10, -1, -1):
                    mAP += max_precisions[j] / 11.0
                count += 1
            elif self.ap_type == "Integral":
                average_precisions = 0.0
                prev_recall = 0.0
                for i in range(len(accu_tp_sum)):
                    # Only steps where recall actually changes contribute.
                    if math.fabs(recall[i] - prev_recall) > 1e-6:
                        average_precisions += precision[i] * \
                            math.fabs(recall[i] - prev_recall)
                        prev_recall = recall[i]

                mAP += average_precisions
                count += 1

        if count != 0:
            mAP /= count
        return mAP * 100.0

    def setUp(self):
        self.op_type = "detection_map"
        self.set_data()

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp):
    """Same fixture as the base case, run with evaluate_difficult disabled."""

    def init_test_case(self):
        super(TestDetectionMAPOpSkipDiff, self).init_test_case()

        self.evaluate_difficult = False

        # Rows are unpacked by calc_map as:
        # image_id label score true_pos false_pos
        self.tf_pos = [
            [0, 1, 0.7, 1, 0],
            [0, 1, 0.3, 0, 1],
            [1, 1, 0.2, 1, 0],
            [1, 2, 0.8, 0, 1],
            [1, 2, 0.1, 1, 0],
            [1, 3, 0.2, 0, 1],
        ]
|
||||
|
||||
|
||||
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue