fix detection_map. test=develop (#22705)

revert-22710-feature/integrated_ps_api
Kaipeng Deng 5 years ago committed by GitHub
parent ee8b22fbec
commit ebc7ffc300
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -420,8 +420,11 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first;
int label_num_pos = it->second;
if (label_num_pos == background_label ||
true_pos.find(label) == true_pos.end()) {
if (label_num_pos == background_label) {
continue;
}
if (true_pos.find(label) == true_pos.end()) {
count++;
continue;
}
auto label_true_pos = true_pos.find(label)->second;

@ -181,7 +181,10 @@ class TestDetectionMAPOp(OpTest):
false_pos[label].append([score, fp])
for (label, label_pos_num) in six.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue
if label_pos_num == 0: continue
if label not in true_pos:
count += 1
continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
@ -281,5 +284,30 @@ class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp):
self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]]
class TestDetectionMAPOp11PointWithClassNoTP(TestDetectionMAPOp):
def init_test_case(self):
self.overlap_threshold = 0.3
self.evaluate_difficult = True
self.ap_type = "11point"
self.label_lod = [[2]]
# label difficult xmin ymin xmax ymax
self.label = [[2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]]
# label score xmin ymin xmax ymax difficult
self.detect_lod = [[1]]
self.detect = [[1, 0.2, 0.8, 0.1, 1.0, 0.3]]
# label score true_pos false_pos
self.tf_pos_lod = [[3, 4]]
self.tf_pos = [[1, 0.2, 1, 0]]
self.class_pos_count = []
self.true_pos_lod = [[]]
self.true_pos = [[]]
self.false_pos_lod = [[]]
self.false_pos = [[]]
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save