You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/research/cv/MaskedFaceRecognition/utils/metric.py

195 lines
7.0 KiB

"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid)
reid/evaluation_metrics/ranking.py. Modifications:
1) Only accepts numpy data input, no torch is involved.
2) Results of each query can be returned.
3) In the single-gallery-shot evaluation case, the number of repeats is changed
from 10 to 100.
"""
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
from sklearn.metrics import average_precision_score
def _unique_sample(ids_dict, num):
    """Build a boolean mask that selects one random position per identity.

    Args:
        ids_dict: mapping from an identity to the list of positions
            (integers in ``[0, num)``) that belong to that identity.
        num: length of the mask to create.

    Returns:
        numpy bool array with shape [num]; for every entry of ``ids_dict``
        exactly one of its positions, drawn with ``np.random.choice``, is
        set to True.
    """
    # The builtin `bool` is used because the `np.bool` alias no longer exists
    # in current NumPy releases.
    mask = np.zeros(num, dtype=bool)
    for positions in ids_dict.values():
        mask[np.random.choice(positions)] = True
    return mask
def cmc(
        distmat,
        query_ids=None,
        gallery_ids=None,
        query_cams=None,
        gallery_cams=None,
        topk=100,
        separate_camera_set=False,
        single_gallery_shot=False,
        first_match_break=False,
        average=True):
    """Compute Cumulative Matching Characteristics (CMC) scores.

    For every query the gallery is ranked by ascending distance and the rank
    of the correct match(es) is accumulated into a cumulative curve.

    Args:
        distmat: numpy array with shape [num_query, num_gallery], the
            pairwise distance between query and gallery samples
        query_ids: numpy array with shape [num_query]
        gallery_ids: numpy array with shape [num_gallery]
        query_cams: numpy array with shape [num_query]; only read when
            `separate_camera_set` is True
        gallery_cams: numpy array with shape [num_gallery]; only read when
            `separate_camera_set` is True
        topk: number of ranks kept in the returned curve
        separate_camera_set: if True, gallery samples taken by the same
            camera as the query are ignored
        single_gallery_shot: if True, one random gallery instance per
            identity is kept and the evaluation is repeated 100 times
        first_match_break: NOTE: currently ignored. The value is forced to
            True inside the function, so only the first correct match of
            every query is counted.
        average: whether to average the results across queries

    Returns:
        If `average` is `False`, a tuple `(ret, is_valid_query, indices)`:
            ret: numpy array with shape [num_query, topk]
            is_valid_query: numpy array with shape [num_query], containing
                0's and 1's, whether each query is valid or not
            indices: numpy array with shape [num_query, num_gallery], the
                gallery indices of each query sorted by ascending distance
        If `average` is `True`, a tuple `(scores, indices)`:
            scores: numpy array with shape [topk], the per-query curves
                summed and divided by the number of valid queries
            indices: same as above

    Raises:
        RuntimeError: if no query has a correct match among its valid
            gallery samples.
    """
    # Ensure numpy array (the optional camera arrays are not type-checked)
    assert isinstance(distmat, np.ndarray)
    assert isinstance(query_ids, np.ndarray)
    assert isinstance(gallery_ids, np.ndarray)

    # NOTE: the caller's `first_match_break` is deliberately overridden; the
    # evaluation always stops at the first correct match.
    first_match_break = True

    num_query, num_gallery = distmat.shape

    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])

    # Without camera filtering every gallery sample takes part in the ranking.
    all_valid = np.ones(num_gallery, dtype=bool)

    # Compute CMC for each query
    ret = np.zeros([num_query, topk])
    is_valid_query = np.zeros(num_query)
    num_valid_queries = 0
    for i in range(num_query):
        if separate_camera_set:
            # Filter out samples from same camera
            valid = (gallery_cams[indices[i]] != query_cams[i])
        else:
            valid = all_valid
        if not np.any(matches[i, valid]):
            continue
        is_valid_query[i] = 1

        if single_gallery_shot:
            repeat = 100
            # Group the valid ranked positions by gallery identity.
            gids = gallery_ids[indices[i][valid]]
            inds = np.where(valid)[0]
            ids_dict = defaultdict(list)
            for j, x in zip(inds, gids):
                ids_dict[x].append(j)
        else:
            repeat = 1

        # Weight of one first-match hit. Dividing by `repeat` keeps the score
        # of a query within [0, 1] when the evaluation is repeated; it is
        # exactly 1 in the usual single-pass case.
        hit_weight = 1. / repeat
        for _ in range(repeat):
            if single_gallery_shot:
                # Randomly choose one instance for each id
                sampled = (valid & _unique_sample(ids_dict, len(valid)))
                index = np.nonzero(matches[i, sampled])[0]
            else:
                index = np.nonzero(matches[i, valid])[0]
            delta = 1. / (len(index) * repeat)
            for j, k in enumerate(index):
                # k - j is the rank of this match once the earlier correct
                # matches are discounted.
                if k - j >= topk:
                    break
                if first_match_break:
                    ret[i, k - j] += hit_weight
                    break
                ret[i, k - j] += delta
        num_valid_queries += 1

    if num_valid_queries == 0:
        raise RuntimeError("No valid query")

    ret = ret.cumsum(axis=1)
    if average:
        return np.sum(ret, axis=0) / num_valid_queries, indices
    return ret, is_valid_query, indices
def mean_ap(
        distmat,
        query_ids=None,
        gallery_ids=None,
        query_cams=None,
        gallery_cams=None,
        average=True):
    """Compute the (mean) average precision of a retrieval result.

    For every query the gallery is ranked by ascending distance and the
    average precision of that ranking is computed with
    `sklearn.metrics.average_precision_score`, using the negated distance as
    the score. No gallery sample is filtered out; `query_cams` and
    `gallery_cams` are accepted but not used.

    Args:
        distmat: numpy array with shape [num_query, num_gallery], the
            pairwise distance between query and gallery samples
        query_ids: numpy array with shape [num_query]
        gallery_ids: numpy array with shape [num_gallery]
        query_cams: numpy array with shape [num_query] (unused)
        gallery_cams: numpy array with shape [num_gallery] (unused)
        average: whether to average the results across queries

    Returns:
        If `average` is `False`, a tuple `(aps, is_valid_query)`:
            aps: numpy array with shape [num_query]; entries of invalid
                queries are 0
            is_valid_query: numpy array with shape [num_query], containing
                0's and 1's, whether each query is valid or not
        If `average` is `True`:
            a scalar, the sum of the average precisions divided by the number
            of valid queries

    Raises:
        RuntimeError: if `average` is `True` and no query has a correct match
            in the gallery (the mean would otherwise be 0 / 0).
    """
    # NOTE: the behavior of `average_precision_score` has changed since
    # scikit-learn 0.19. Version 0.18.1 has same results as the Matlab
    # evaluation code by Zhun Zhong
    # (https://github.com/zhunzhong07/person-re-ranking/
    # blob/master/evaluation/utils/evaluation.m) and by Liang Zheng
    # (http://www.liangzheng.org/Project/project_reid.html), so with a newer
    # scikit-learn the mAP score may not be totally identical to theirs.

    # Ensure numpy array (the optional camera arrays are not type-checked)
    assert isinstance(distmat, np.ndarray)
    assert isinstance(query_ids, np.ndarray)
    assert isinstance(gallery_ids, np.ndarray)

    num_query, _ = distmat.shape

    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])

    # Compute AP for each query
    aps = np.zeros(num_query)
    is_valid_query = np.zeros(num_query)
    for i in range(num_query):
        y_true = matches[i]
        # A query without any correct match in the gallery is skipped.
        if not np.any(y_true):
            continue
        # Smaller distance means better match, hence the negation.
        y_score = -distmat[i][indices[i]]
        is_valid_query[i] = 1
        aps[i] = average_precision_score(y_true, y_score)

    if average:
        num_valid_queries = np.sum(is_valid_query)
        if num_valid_queries == 0:
            raise RuntimeError("No valid query")
        return float(np.sum(aps)) / num_valid_queries
    return aps, is_valid_query