|
|
|
@ -21,7 +21,8 @@ __all__ = [
|
|
|
|
|
"chunk_evaluator", "sum_evaluator", "column_sum_evaluator",
|
|
|
|
|
"value_printer_evaluator", "gradient_printer_evaluator",
|
|
|
|
|
"maxid_printer_evaluator", "maxframe_printer_evaluator",
|
|
|
|
|
"seqtext_printer_evaluator", "classification_error_printer_evaluator"
|
|
|
|
|
"seqtext_printer_evaluator", "classification_error_printer_evaluator",
|
|
|
|
|
"detection_map_evaluator"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -31,10 +32,11 @@ class EvaluatorAttribute(object):
|
|
|
|
|
FOR_RANK = 1 << 2
|
|
|
|
|
FOR_PRINT = 1 << 3
|
|
|
|
|
FOR_UTILS = 1 << 4
|
|
|
|
|
FOR_DETECTION = 1 << 5
|
|
|
|
|
|
|
|
|
|
KEYS = [
|
|
|
|
|
"for_classification", "for_regression", "for_rank", "for_print",
|
|
|
|
|
"for_utils"
|
|
|
|
|
"for_utils", "for_detection"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -57,22 +59,25 @@ def evaluator(*attrs):
|
|
|
|
|
return impl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluator_base(
|
|
|
|
|
input,
|
|
|
|
|
type,
|
|
|
|
|
label=None,
|
|
|
|
|
weight=None,
|
|
|
|
|
name=None,
|
|
|
|
|
chunk_scheme=None,
|
|
|
|
|
num_chunk_types=None,
|
|
|
|
|
classification_threshold=None,
|
|
|
|
|
positive_label=None,
|
|
|
|
|
dict_file=None,
|
|
|
|
|
result_file=None,
|
|
|
|
|
num_results=None,
|
|
|
|
|
delimited=None,
|
|
|
|
|
top_k=None,
|
|
|
|
|
excluded_chunk_types=None, ):
|
|
|
|
|
def evaluator_base(input,
|
|
|
|
|
type,
|
|
|
|
|
label=None,
|
|
|
|
|
weight=None,
|
|
|
|
|
name=None,
|
|
|
|
|
chunk_scheme=None,
|
|
|
|
|
num_chunk_types=None,
|
|
|
|
|
classification_threshold=None,
|
|
|
|
|
positive_label=None,
|
|
|
|
|
dict_file=None,
|
|
|
|
|
result_file=None,
|
|
|
|
|
num_results=None,
|
|
|
|
|
delimited=None,
|
|
|
|
|
top_k=None,
|
|
|
|
|
excluded_chunk_types=None,
|
|
|
|
|
overlap_threshold=None,
|
|
|
|
|
background_id=None,
|
|
|
|
|
evaluate_difficult=None,
|
|
|
|
|
ap_type=None):
|
|
|
|
|
"""
|
|
|
|
|
Evaluator will evaluate the network status while training/testing.
|
|
|
|
|
|
|
|
|
@ -107,6 +112,14 @@ def evaluator_base(
|
|
|
|
|
:type weight: LayerOutput.
|
|
|
|
|
:param top_k: number k in top-k error rate
|
|
|
|
|
:type top_k: int
|
|
|
|
|
:param overlap_threshold: In detection tasks to filter detection results
|
|
|
|
|
:type overlap_threshold: float
|
|
|
|
|
:param background_id: Identifier of background class
|
|
|
|
|
:type background_id: int
|
|
|
|
|
:param evaluate_difficult: Whether to evaluate difficult objects
|
|
|
|
|
:type evaluate_difficult: bool
|
|
|
|
|
:param ap_type: How to calculate average persicion
|
|
|
|
|
:type ap_type: str
|
|
|
|
|
"""
|
|
|
|
|
# inputs type assertions.
|
|
|
|
|
assert classification_threshold is None or isinstance(
|
|
|
|
@ -136,7 +149,61 @@ def evaluator_base(
|
|
|
|
|
delimited=delimited,
|
|
|
|
|
num_results=num_results,
|
|
|
|
|
top_k=top_k,
|
|
|
|
|
excluded_chunk_types=excluded_chunk_types, )
|
|
|
|
|
excluded_chunk_types=excluded_chunk_types,
|
|
|
|
|
overlap_threshold=overlap_threshold,
|
|
|
|
|
background_id=background_id,
|
|
|
|
|
evaluate_difficult=evaluate_difficult,
|
|
|
|
|
ap_type=ap_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@evaluator(EvaluatorAttribute.FOR_DETECTION)
|
|
|
|
|
@wrap_name_default()
|
|
|
|
|
def detection_map_evaluator(input,
|
|
|
|
|
label,
|
|
|
|
|
overlap_threshold=0.5,
|
|
|
|
|
background_id=0,
|
|
|
|
|
evaluate_difficult=False,
|
|
|
|
|
ap_type="11point",
|
|
|
|
|
name=None):
|
|
|
|
|
"""
|
|
|
|
|
Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection.
|
|
|
|
|
|
|
|
|
|
The detection mAP Evaluator based on the output of detection_output layer counts
|
|
|
|
|
the true positive and the false positive bbox and integral them to get the
|
|
|
|
|
mAP.
|
|
|
|
|
|
|
|
|
|
The simple usage is:
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
eval = detection_map_evaluator(input=det_output,label=lbl)
|
|
|
|
|
|
|
|
|
|
:param input: Input layer.
|
|
|
|
|
:type input: LayerOutput
|
|
|
|
|
:param label: Label layer.
|
|
|
|
|
:type label: LayerOutput
|
|
|
|
|
:param overlap_threshold: The bbox overlap threshold of a true positive.
|
|
|
|
|
:type overlap_threshold: float
|
|
|
|
|
:param background_id: The background class index.
|
|
|
|
|
:type background_id: int
|
|
|
|
|
:param evaluate_difficult: Whether evaluate a difficult ground truth.
|
|
|
|
|
:type evaluate_difficult: bool
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(input, list):
|
|
|
|
|
input = [input]
|
|
|
|
|
|
|
|
|
|
if label:
|
|
|
|
|
input.append(label)
|
|
|
|
|
|
|
|
|
|
evaluator_base(
|
|
|
|
|
name=name,
|
|
|
|
|
type="detection_map",
|
|
|
|
|
input=input,
|
|
|
|
|
label=label,
|
|
|
|
|
overlap_threshold=overlap_threshold,
|
|
|
|
|
background_id=background_id,
|
|
|
|
|
evaluate_difficult=evaluate_difficult,
|
|
|
|
|
ap_type=ap_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@evaluator(EvaluatorAttribute.FOR_CLASSIFICATION)
|
|
|
|
|