You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/official/cv/centerface/dependency/evaluate/eval.py

317 lines
9.8 KiB

"""
WiderFace evaluation code
author: wondervictor
mail: tianhengcheng@gmail.com
copyright@wondervictor
MIT License
Copyright (c) 2018 Vic Chan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import division
import os
import pickle
import argparse
import numpy as np
from scipy.io import loadmat
from bbox import bbox_overlaps
def get_gt_boxes(gt_dir):
    """Load the WiderFace validation ground truth stored as .mat files.

    ``gt_dir`` must contain wider_face_val.mat, wider_hard_val.mat,
    wider_medium_val.mat and wider_easy_val.mat.

    Returns:
        (facebox_list, event_list, file_list,
         hard_gt_list, medium_gt_list, easy_gt_list)
    """
    # Change 'wider_face_val.mat' here if your own ground-truth file is named differently.
    mat_names = ('wider_face_val.mat', 'wider_hard_val.mat',
                 'wider_medium_val.mat', 'wider_easy_val.mat')
    mats = {name: loadmat(os.path.join(gt_dir, name)) for name in mat_names}
    full = mats['wider_face_val.mat']
    return (
        full['face_bbx_list'],
        full['event_list'],
        full['file_list'],
        mats['wider_hard_val.mat']['gt_list'],
        mats['wider_medium_val.mat']['gt_list'],
        mats['wider_easy_val.mat']['gt_list'],
    )
def get_gt_boxes_from_txt(gt_path, cache_dir):
    """Parse ground-truth boxes from the plain-text annotation file.

    The file is read as a sequence of records: an image-name line
    (recognised by containing '--'), one line that is skipped (the box
    count), then one line per box whose first four whitespace-separated
    numbers are kept. Blank lines are ignored.

    The parsed dict is cached as ``gt_cache.pkl`` in ``cache_dir``. If that
    file already exists it is returned as-is without reading ``gt_path``, so
    delete it whenever the annotations change.

    Returns:
        dict mapping image-name line -> float32 array of its boxes (one row
        of 4 values per box; an empty array when the record has no box lines).
    """
    cache_file = os.path.join(cache_dir, 'gt_cache.pkl')
    if os.path.exists(cache_file):
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # cache_dir at a directory whose contents you trust (normally the
        # cache written by this function).
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    with open(gt_path, 'r') as f:
        lines = [line.rstrip('\r\n') for line in f]

    boxes = {}
    current_boxes = []
    current_name = None
    # 0: before the first image name, 1: next line is the count, 2: reading box lines
    state = 0
    for line in lines:
        if not line.strip():
            continue
        if state == 0 and '--' in line:
            state = 1
            current_name = line
            continue
        if state == 1:
            state = 2
            continue
        if state == 2 and '--' in line:
            state = 1
            boxes[current_name] = np.array(current_boxes).astype('float32')
            current_name = line
            current_boxes = []
            continue
        if state == 2:
            # split() without an argument tolerates repeated/trailing spaces.
            current_boxes.append([float(x) for x in line.split()[:4]])

    # A record is only stored when the next name line appears, so the final
    # record has to be flushed explicitly; otherwise the last image is lost.
    if current_name is not None:
        boxes[current_name] = np.array(current_boxes).astype('float32')

    with open(cache_file, 'wb') as f:
        pickle.dump(boxes, f)
    return boxes
def read_pred_file(filepath):
    """Read one per-image prediction file.

    Layout: line 1 is the image path, and line 2 is skipped (presumably the
    box count). Every following non-blank line holds whitespace-separated
    numbers for one box (e.g. ``x y w h score``).

    Returns:
        (image file name with its directory stripped,
         float array with one row per box; empty when no boxes are listed).
    """
    with open(filepath, 'r') as f:
        lines = f.readlines()
    img_file = lines[0].rstrip('\n\r')
    # split() (no argument) tolerates trailing/repeated spaces, which
    # split(' ') turned into '' and float('') rejected; blank lines are skipped.
    rows = [[float(a) for a in line.split()] for line in lines[2:] if line.strip()]
    boxes = np.array(rows).astype('float')
    return img_file.split('/')[-1], boxes
def get_preds(pred_dir):
    """Load every prediction file below ``pred_dir``.

    ``pred_dir`` is expected to contain one sub-directory per event, each
    holding one prediction file per image (parsed by ``read_pred_file``).
    Entries of ``pred_dir`` that are not directories are skipped.

    Returns:
        dict: event directory name -> {image name without '.jpg': boxes array}.
    """
    boxes = {}
    for event in os.listdir(pred_dir):
        event_dir = os.path.join(pred_dir, event)
        if not os.path.isdir(event_dir):
            # Stray files (e.g. .DS_Store) would make os.listdir below raise.
            continue
        current_event = {}
        for imgtxt in os.listdir(event_dir):
            imgname, box = read_pred_file(os.path.join(event_dir, imgtxt))
            # str.rstrip('.jpg') strips any trailing '.', 'j', 'p', 'g'
            # characters rather than the suffix, so remove the extension explicitly.
            if imgname.endswith('.jpg'):
                imgname = imgname[:-len('.jpg')]
            current_event[imgname] = box
        boxes[event] = current_event
    return boxes
def norm_score(pred_norm):
    """Min-max normalise all prediction scores to [0, 1], in place.

    pred_norm: {event: {image: boxes}}. In each non-empty ``boxes`` array the
    score is the last column (rows such as [x, y, w, h, s]).

    The min and max are taken over every prediction in the dataset, so scores
    stay comparable across images. If there are no predictions, or every
    score is identical, there is nothing to spread and scores are left as-is.
    """
    # Start from +/-inf so scores outside [0, 1] still yield their true min/max.
    max_score = -np.inf
    min_score = np.inf
    for event_preds in pred_norm.values():
        for v in event_preds.values():
            if v.size == 0:
                continue
            max_score = max(np.max(v[:, -1]), max_score)
            min_score = min(np.min(v[:, -1]), min_score)
    diff = max_score - min_score
    # diff is -inf with no predictions and 0 when all scores are equal;
    # dividing would produce NaN/inf scores, so leave them untouched.
    if not diff > 0:
        return
    for event_preds in pred_norm.values():
        for v in event_preds.values():
            if v.size == 0:
                continue
            v[:, -1] = (v[:, -1] - min_score) / diff
def image_eval(pred_eval, gt, ignore, iou_thresh):
    """Match one image's predictions against its ground-truth boxes.

    pred_eval: Nx5 rows of [x, y, w, h, score], processed in row order.
    gt: Nx4 rows of [x, y, w, h].
    ignore: one flag per gt box. A prediction whose best-matching gt has flag
        0 is discarded (proposal -1) instead of counting as a hit.
    iou_thresh: minimum overlap for a prediction to match its best gt.

    Returns:
        pred_recall: for each prediction h, the number of gt boxes recalled
            by predictions 0..h.
        proposal_list: 1 for predictions that count as proposals, -1 for
            predictions matched to an ignored gt.
    """
    boxes = pred_eval.copy()
    gts = gt.copy()
    n_pred = boxes.shape[0]
    pred_recall = np.zeros(n_pred)
    recall_list = np.zeros(gts.shape[0])
    proposal_list = np.ones(n_pred)

    # Convert widths/heights into bottom-right corners: (x, y, w, h) -> (x1, y1, x2, y2).
    boxes[:, 2:4] += boxes[:, 0:2]
    gts[:, 2:4] += gts[:, 0:2]

    # Presumably a (num_pred, num_gt) IoU matrix from the bbox extension -- confirm.
    overlaps = bbox_overlaps(boxes[:, :4], gts)
    for h in range(n_pred):
        row = overlaps[h]
        best = row.argmax()
        if row.max() >= iou_thresh:
            if ignore[best] == 0:
                recall_list[best] = -1
                proposal_list[h] = -1
            elif recall_list[best] == 0:
                recall_list[best] = 1
        pred_recall[h] = np.count_nonzero(recall_list == 1)
    return pred_recall, proposal_list
def img_pr_info(thresh_num, pred_info, proposal_list, pred_recall):
    """Per-image precision/recall counts sampled at ``thresh_num`` score thresholds.

    For threshold t (score cut-off 1 - (t + 1) / thresh_num), the last
    prediction row whose score (column 4) passes the cut-off is located.
    Row t then holds [number of proposal_list entries equal to 1 up to that
    row, pred_recall at that row]. Rows where no prediction passes stay [0, 0].
    """
    pr_info = np.zeros((thresh_num, 2)).astype('float')
    scores = pred_info[:, 4]
    for t in range(thresh_num):
        cutoff = 1 - (t + 1) / thresh_num
        passing = np.where(scores >= cutoff)[0]
        if passing.size:
            last = passing[-1]
            pr_info[t, 0] = np.count_nonzero(proposal_list[:last + 1] == 1)
            pr_info[t, 1] = pred_recall[last]
    return pr_info
def dataset_pr_info(thresh_num, pr_curve, count_face):
    """Turn dataset-wide PR counts into [precision, recall] per threshold.

    pr_curve: array whose first ``thresh_num`` rows hold summed
        [proposal count, recalled-face count].
    count_face: total number of ground-truth faces considered.

    Returns a (thresh_num, 2) array of [precision, recall]. When a denominator
    is zero (no proposals at a threshold, or no faces at all), the entry is 0
    instead of NaN/inf, so those values never reach the AP computation.
    """
    proposals = np.asarray(pr_curve[:thresh_num, 0], dtype=float)
    recalled = np.asarray(pr_curve[:thresh_num, 1], dtype=float)
    precision = np.zeros(thresh_num)
    np.divide(recalled, proposals, out=precision, where=proposals != 0)
    pr_curve_t = np.zeros((thresh_num, 2))
    pr_curve_t[:, 0] = precision
    if count_face:
        pr_curve_t[:, 1] = recalled / count_face
    return pr_curve_t
def voc_ap(rec, prec):
    """VOC-style average precision: area under the interpolated PR curve.

    ``rec`` and ``prec`` are matching recall/precision samples. Precision is
    first replaced by its running maximum taken from the high-recall end (the
    "envelope"). It is then summed, weighted by each recall step. The
    integration assumes ``rec`` is non-decreasing.
    """
    # Sentinels anchor the curve at recall 0 and recall 1.
    recall_pts = np.concatenate(([0.], rec, [1.]))
    precision_pts = np.concatenate(([0.], prec, [0.]))
    # Reverse, take the cumulative max, reverse back: each point gets the
    # best precision achievable at any recall >= its own.
    envelope = np.maximum.accumulate(precision_pts[::-1])[::-1]
    # Only positions where recall actually changes contribute area.
    steps = np.flatnonzero(recall_pts[1:] != recall_pts[:-1])
    return np.sum((recall_pts[steps + 1] - recall_pts[steps]) * envelope[steps + 1])
def _setting_pr_curve(preds, facebox_list, event_list, file_list, gt_list,
                      iou_thresh, thresh_num):
    """Sum per-image PR counts over the whole dataset for one difficulty setting.

    Returns (pr_curve, count_face, missing):
        pr_curve: (thresh_num, 2) summed [proposal count, recalled-face count].
        count_face: number of faces selected by ``gt_list`` over all images.
        missing: number of ground-truth images with no prediction entry.
    """
    pr_curve = np.zeros((thresh_num, 2), dtype=float)
    count_face = 0
    missing = 0
    for i, event in enumerate(event_list):
        event_name = str(event[0][0])
        img_list = file_list[i][0]
        # A missing event directory is treated like missing images instead
        # of aborting the whole evaluation with KeyError.
        pred_list = preds.get(event_name, {})
        sub_gt_list = gt_list[i][0]
        gt_bbx_list = facebox_list[i][0]
        for j, img in enumerate(img_list):
            keep_index = sub_gt_list[j][0]
            # Count faces even for images without predictions: a missing
            # prediction means every face there was missed. Skipping them
            # would inflate recall.
            count_face += len(keep_index)
            pred_info = pred_list.get(str(img[0][0]))
            if pred_info is None:
                missing += 1
                continue
            gt_boxes = gt_bbx_list[j][0].astype('float')
            if gt_boxes.size == 0 or pred_info.size == 0:
                continue
            # Flag the gt boxes that belong to this setting; keep_index appears
            # to hold 1-based (MATLAB-style) indices, hence the -1.
            ignore = np.zeros(gt_boxes.shape[0])
            if keep_index.size != 0:
                ignore[keep_index - 1] = 1
            pred_recall, proposal_list = image_eval(pred_info, gt_boxes, ignore, iou_thresh)
            pr_curve += img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)
    return pr_curve, count_face, missing


def evaluation(pred_evaluation, gt_path, iou_thresh=0.4, thresh_num=1000):
    """Evaluate WiderFace predictions and print the Easy/Medium/Hard AP.

    Args:
        pred_evaluation: prediction directory (one sub-directory per event).
        gt_path: directory holding the ground-truth .mat files.
        iou_thresh: minimum overlap for a prediction to match a face.
        thresh_num: number of score thresholds used to sample the PR curve.

    Returns:
        list of AP values in the order [easy, medium, hard].
    """
    pred_dir = pred_evaluation
    preds = get_preds(pred_dir)
    norm_score(preds)
    facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = get_gt_boxes(gt_path)
    # Order defines the result order: [easy, medium, hard].
    setting_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
    aps = []
    missing = 0
    for gt_list in setting_gts:
        pr_curve, count_face, missing = _setting_pr_curve(
            preds, facebox_list, event_list, file_list, gt_list, iou_thresh, thresh_num)
        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        precision = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        aps.append(voc_ap(recall, precision))
    print("==================== Results = ====================", pred_dir)
    print("Easy Val AP: {}".format(aps[0]))
    print("Medium Val AP: {}".format(aps[1]))
    print("Hard Val AP: {}".format(aps[2]))
    print("=================================================")
    if missing:
        # Same image set for every setting, so the last count is representative.
        print("Warning: {} ground-truth images had no prediction file.".format(missing))
    return aps
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--pred', default='',
                        help='test output, txt contain box positions and scores')
    parser.add_argument('-g', '--gt', default='', help='ground truth path, mat format')
    args = parser.parse_args()
    # Previously an invalid --pred silently did nothing; report it instead
    # (parser.error prints usage and exits with status 2).
    if not os.path.isdir(args.pred):
        parser.error('--pred must be an existing directory of predictions, got {!r}'.format(args.pred))
    evaluation(args.pred, args.gt)