From 32d17598b9a225e31ad22dfe499cd07c53f92071 Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Mon, 22 Oct 2018 15:04:33 +0800 Subject: [PATCH 1/5] Refine detection mAP in metrics.py. evaluator.py throws warnings and tell users to use DetectionMAP in metrics.py, but it is wrong in metrics.py. So refine this API in metrics.py. test=develop --- python/paddle/fluid/evaluator.py | 2 +- python/paddle/fluid/metrics.py | 243 +++++++++++++++++++++++-------- 2 files changed, 183 insertions(+), 62 deletions(-) diff --git a/python/paddle/fluid/evaluator.py b/python/paddle/fluid/evaluator.py index 7a82038ff7..c84dd4bc47 100644 --- a/python/paddle/fluid/evaluator.py +++ b/python/paddle/fluid/evaluator.py @@ -316,7 +316,7 @@ class DetectionMAP(Evaluator): gt_label (Variable): The ground truth label index, which is a LoDTensor with shape [N, 1]. gt_box (Variable): The ground truth bounding box (bbox), which is a - LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax]. + LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax]. gt_difficult (Variable|None): Whether this ground truth is a difficult bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, it means all the ground truth labels are not difficult bbox. diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 0c2800dcf3..60d0e35d3a 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -478,67 +478,6 @@ class EditDistance(MetricBase): return avg_distance, avg_instance_error -class DetectionMAP(MetricBase): - """ - Calculate the detection mean average precision (mAP). - mAP is the metric to measure the accuracy of object detectors - like Faster R-CNN, SSD, etc. - It is the average of the maximum precisions at different recall values. - Please get more information from the following articles: - https://sanchom.wordpress.com/tag/average-precision/ - - https://arxiv.org/abs/1512.02325 - - The general steps are as follows: - - 1. calculate the true positive and false positive according to the input - of detection and labels. - 2. calculate mAP value, support two versions: '11 point' and 'integral'. - - Examples: - .. code-block:: python - - pred = fluid.layers.fc(input=data, size=1000, act="tanh") - batch_map = layers.detection_map( - input, - label, - class_num, - background_label, - overlap_threshold=overlap_threshold, - evaluate_difficult=evaluate_difficult, - ap_version=ap_version) - metric = fluid.metrics.DetectionMAP() - for data in train_reader(): - loss, preds, labels = exe.run(fetch_list=[cost, batch_map]) - batch_size = data[0] - metric.update(value=batch_map, weight=batch_size) - numpy_map = metric.eval() - """ - - def __init__(self, name=None): - super(DetectionMAP, self).__init__(name) - # the current map value - self.value = .0 - self.weight = .0 - - def update(self, value, weight): - if not _is_number_or_matrix_(value): - raise ValueError( - "The 'value' must be a number(int, float) or a numpy ndarray.") - if not _is_number_(weight): - raise ValueError("The 'weight' must be a number(int, float).") - self.value += value - self.weight += weight - - def eval(self): - if self.weight == 0: - raise ValueError( - "There is no data in DetectionMAP Metrics. " - "Please check layers.detection_map output has added to DetectionMAP." - ) - return self.value / self.weight - - class Auc(MetricBase): """ Auc metric adapts to the binary classification. @@ -616,3 +555,185 @@ class Auc(MetricBase): idx -= 1 return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0 + + +class DetectionMAP(object): + """ + Calculate the detection mean average precision (mAP). + + The general steps are as follows: + 1. calculate the true positive and false positive according to the input + of detection and labels. + 2. calculate mAP value, support two versions: '11 point' and 'integral'. + + Please get more information from the following articles: + https://sanchom.wordpress.com/tag/average-precision/ + https://arxiv.org/abs/1512.02325 + + Args: + input (Variable): The detection results, which is a LoDTensor with shape + [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax]. + gt_label (Variable): The ground truth label index, which is a LoDTensor + with shape [N, 1]. + gt_box (Variable): The ground truth bounding box (bbox), which is a + LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax]. + gt_difficult (Variable|None): Whether this ground truth is a difficult + bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, + it means all the ground truth labels are not difficult bbox. + class_num (int): The class number. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all categories will be + considered, 0 by defalut. + overlap_threshold (float): The threshold for deciding true/false + positive, 0.5 by defalut. + evaluate_difficult (bool): Whether to consider difficult ground truth + for evaluation, True by defalut. This argument does not work when + gt_difficult is None. + ap_version (string): The average precision calculation ways, it must be + 'integral' or '11point'. Please check + https://sanchom.wordpress.com/tag/average-precision/ for details. + - 11point: the 11-point interpolated average precision. + - integral: the natural integral of the precision-recall curve. + + Examples: + .. code-block:: python + + exe = fluid.executor(place) + map_evaluator = fluid.Evaluator.DetectionMAP(input, + gt_label, gt_box, gt_difficult) + cur_map, accum_map = map_evaluator.get_map_var() + fetch = [cost, cur_map, accum_map] + for epoch in PASS_NUM: + map_evaluator.reset(exe) + for data in batches: + loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch) + + In the above example: + + 'cur_map_v' is the mAP of current mini-batch. + 'accum_map_v' is the accumulative mAP of one pass. + """ + + def __init__(self, + input, + gt_label, + gt_box, + gt_difficult=None, + class_num=None, + background_label=0, + overlap_threshold=0.5, + evaluate_difficult=True, + ap_version='integral'): + from . import layers + from .layer_helper import LayerHelper + from .initializer import Constant + + self.helper = LayerHelper('map_eval') + gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) + if gt_difficult: + gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype) + label = layers.concat([gt_label, gt_difficult, gt_box], axis=1) + else: + label = layers.concat([gt_label, gt_box], axis=1) + + # calculate mean average precision (mAP) of current mini-batch + map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + ap_version=ap_version) + + states = [] + states.append( + self._create_state( + dtype='int32', shape=None, suffix='accum_pos_count')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_true_pos')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_false_pos')) + var = self._create_state(dtype='int32', shape=[1], suffix='has_state') + self.helper.set_variable_initializer( + var, initializer=Constant(value=int(0))) + self.has_state = var + + # calculate accumulative mAP + accum_map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + has_state=self.has_state, + input_states=states, + out_states=states, + ap_version=ap_version) + + layers.fill_constant( + shape=self.has_state.shape, + value=1, + dtype=self.has_state.dtype, + out=self.has_state) + + self.cur_map = map + self.accum_map = accum_map + + def _create_state(self, suffix, dtype, shape): + """ + Create state variable. + Args: + suffix(str): the state suffix. + dtype(str|core.VarDesc.VarType): the state data type + shape(tuple|list): the shape of state + Returns: State variable + """ + from . import unique_name + state = self.helper.create_variable( + name="_".join([unique_name.generate(self.helper.name), suffix]), + persistable=True, + dtype=dtype, + shape=shape) + return state + + def get_map_var(self): + """ + Returns: mAP variable of current mini-batch and + accumulative mAP variable cross mini-batches. + """ + return self.cur_map, self.accum_map + + def reset(self, executor, reset_program=None): + """ + reset metric states at the begin of each pass/user specified batch. + + Args: + executor(Executor|ParallelExecutor): a executor for executing + the reset_program. + reset_program(Program|None): a single Program for reset process. + If None, will create a Program. + """ + from .framework import Program, Variable, program_guard + from . import layers + + def _clone_var_(block, var): + assert isinstance(var, Variable) + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=var.persistable) + + if reset_program is None: + reset_program = Program() + with program_guard(main_program=reset_program): + var = _clone_var_(reset_program.current_block(), self.has_state) + layers.fill_constant( + shape=var.shape, value=0, dtype=var.dtype, out=var) + executor.run(reset_program) From 2939fc9f038f3b128e6247c5ce8bf88dfdefe830 Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Mon, 22 Oct 2018 20:16:41 +0800 Subject: [PATCH 2/5] test=develop --- python/paddle/fluid/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 60d0e35d3a..f409da90c1 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -610,8 +610,8 @@ class DetectionMAP(object): In the above example: - 'cur_map_v' is the mAP of current mini-batch. - 'accum_map_v' is the accumulative mAP of one pass. + 'cur_map_v' is the mAP of current mini-batch. + 'accum_map_v' is the accumulative mAP of one pass. """ def __init__(self, From 60229c1e3cc3bac9c7e4c7a341dd65026ce0bd92 Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Wed, 24 Oct 2018 20:20:07 +0800 Subject: [PATCH 3/5] Follow comments. test=develop --- python/paddle/fluid/metrics.py | 16 +++--- .../fluid/tests/unittests/test_metrics.py | 49 +++++++++++++++++++ 2 files changed, 56 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_metrics.py diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index f409da90c1..42d9aeb67f 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -13,8 +13,6 @@ # limitations under the License. """ Fluid Metrics - -The metrics are accomplished via Python natively. """ from __future__ import print_function @@ -24,6 +22,12 @@ import copy import warnings import six +from .layer_helper import LayerHelper +from .initializer import Constant +from . import unique_name +from .framework import Program, Variable, program_guard +from . import layers + __all__ = [ 'MetricBase', 'CompositeMetric', @@ -598,7 +602,7 @@ class DetectionMAP(object): Examples: .. code-block:: python - exe = fluid.executor(place) + exe = fluid.Executor(place) map_evaluator = fluid.Evaluator.DetectionMAP(input, gt_label, gt_box, gt_difficult) cur_map, accum_map = map_evaluator.get_map_var() @@ -624,9 +628,6 @@ class DetectionMAP(object): overlap_threshold=0.5, evaluate_difficult=True, ap_version='integral'): - from . import layers - from .layer_helper import LayerHelper - from .initializer import Constant self.helper = LayerHelper('map_eval') gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) @@ -692,7 +693,6 @@ class DetectionMAP(object): shape(tuple|list): the shape of state Returns: State variable """ - from . import unique_name state = self.helper.create_variable( name="_".join([unique_name.generate(self.helper.name), suffix]), persistable=True, @@ -717,8 +717,6 @@ class DetectionMAP(object): reset_program(Program|None): a single Program for reset process. If None, will create a Program. """ - from .framework import Program, Variable, program_guard - from . import layers def _clone_var_(block, var): assert isinstance(var, Variable) diff --git a/python/paddle/fluid/tests/unittests/test_metrics.py b/python/paddle/fluid/tests/unittests/test_metrics.py new file mode 100644 index 0000000000..ec27884cae --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_metrics.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +from paddle.fluid.framework import Program, program_guard + + +class TestMetricsDetectionMap(unittest.TestCase): + def test_detection_map(self): + program = fluid.Program() + with program_guard(program): + detect_res = fluid.layers.data( + name='detect_res', + shape=[10, 6], + append_batch_size=False, + dtype='float32') + label = fluid.layers.data( + name='label', + shape=[10, 1], + append_batch_size=False, + dtype='float32') + box = fluid.layers.data( + name='bbox', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + map_eval = fluid.metrics.DetectionMAP( + detect_res, label, box, class_num=21) + cur_map, accm_map = map_eval.get_map_var() + self.assertIsNotNone(cur_map) + self.assertIsNotNone(accm_map) + print(str(program)) + + +if __name__ == '__main__': + unittest.main() From 80ee069b9de7eb0a7faf16911e708a9895147d4c Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Wed, 24 Oct 2018 22:02:30 +0800 Subject: [PATCH 4/5] Refince comments. test=develop --- python/paddle/fluid/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 42d9aeb67f..5e03caa603 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -712,7 +712,7 @@ class DetectionMAP(object): reset metric states at the begin of each pass/user specified batch. Args: - executor(Executor|ParallelExecutor): a executor for executing + executor(Executor): a executor for executing the reset_program. reset_program(Program|None): a single Program for reset process. If None, will create a Program. From 5ca9c2d04fc803d3057b7b2662e58189d3ff22e7 Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Thu, 25 Oct 2018 18:04:37 +0800 Subject: [PATCH 5/5] Update code test=develop --- python/paddle/fluid/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 5e03caa603..25d43be3b7 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -709,7 +709,7 @@ class DetectionMAP(object): def reset(self, executor, reset_program=None): """ - reset metric states at the begin of each pass/user specified batch. + Reset metric states at the begin of each pass/user specified batch. Args: executor(Executor): a executor for executing