# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition eval."""
import os
import time
import math
from pprint import pformat
import numpy as np
import cv2

import mindspore.dataset.transforms.py_transforms as transforms
import mindspore.dataset.vision.py_transforms as vision
import mindspore.dataset as de
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config_inference
from src.backbone.resnet import get_backbone
from src.my_logging import get_logger

devid = int(os.getenv('DEVICE_ID', '0'))  # fall back to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
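# GRAPH_MODE compiles the network into a graph executed on the Ascend NPU chosen above.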


class TxtDataset():
    '''Dataset reading "image_path label" lines from one or more list files.'''
    def __init__(self, root_all, filenames):
        super(TxtDataset, self).__init__()
        self.imgs = []
        self.labels = []
        for root, filename in zip(root_all, filenames):
            with open(filename, "r") as fin:
                for line in fin:
                    self.imgs.append(os.path.join(root, line.strip().split(" ")[0]))
                    # keep the full stripped line; it is the key used for embeddings later
                    self.labels.append(line.strip())

    def __getitem__(self, index):
        try:
            img = cv2.cvtColor(cv2.imread(self.imgs[index]), cv2.COLOR_BGR2RGB)
        except Exception:
            # cv2.imread returns None for unreadable files; log the path and re-raise
            print(self.imgs[index])
            raise
        return img, index

    def __len__(self):
        return len(self.imgs)

    def get_all_labels(self):
        return self.labels
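

# NOTE: this evaluation runs in a single process, so DistributedSampler below is
# effectively an identity sampler (num_replicas=1, rank=0 yields every index in order).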


class DistributedSampler():
    '''DistributedSampler'''
    def __init__(self, dataset):
        self.dataset = dataset
        self.num_replicas = 1
        self.rank = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        indices = indices[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.num_samples


def get_dataloader(img_predix_all, img_list_all, batch_size, img_transforms):
    dataset = TxtDataset(img_predix_all, img_list_all)
    sampler = DistributedSampler(dataset)
    dataset_column_names = ["image", "index"]
    ds = de.GeneratorDataset(dataset, column_names=dataset_column_names, sampler=sampler)
    ds = ds.map(input_columns=["image"], operations=img_transforms)
    ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=False)
    ds = ds.repeat(1)

    return ds, len(dataset), dataset.get_all_labels()
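

# Naming (inferred from usage below): each "zj" line is a gallery/enrollment photo,
# the "jk" lines are probe photos of the same identity, and the "dis" images form
# the distractor set the probes are ranked against.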


def generate_test_pair(jk_list, zj_list):
    '''Group jk/zj list lines by identity label and pair each zj line with all jk lines of that identity.'''
    file_paths = [jk_list, zj_list]
    jk_dict = {}
    zj_dict = {}
    jk_zj_dict_list = [jk_dict, zj_dict]
    for path, x_dict in zip(file_paths, jk_zj_dict_list):
        with open(path, 'r') as fr:
            for line in fr:
                label = line.strip().split(' ')[1]
                tmp = x_dict.get(label, [])
                tmp.append(line.strip())
                x_dict[label] = tmp
    zj2jk_pairs = []
    for key in jk_dict:
        jk_file_list = jk_dict[key]
        zj_file_list = zj_dict[key]
        for zj_file in zj_file_list:
            zj2jk_pairs.append([zj_file, jk_file_list])
    return zj2jk_pairs


def check_minmax(args, data, min_value=0.99, max_value=1.01):
    '''Check that the L2 norms of the embeddings stay close to 1 after normalization.'''
    min_data = data.min()
    max_data = data.max()
    if np.isnan(min_data) or np.isnan(max_data):
        args.logger.info('ERROR, nan happened, please check if used fp16 or other error')
        # raise ValueError so the `except ValueError` handlers in main() actually catch it
        raise ValueError('nan in embedding norms')
    if min_data < min_value or max_data > max_value:
        args.logger.info('ERROR, min or max is out of range, range=[{}, {}], minmax=[{}, {}]'.format(
            min_value, max_value, min_data, max_data))
        raise ValueError('embedding norm out of range')


def get_model(args):
    '''Build the backbone and load checkpoint weights for inference.'''
    net = get_backbone(args)
    if args.fp16:
        net.add_flags_recursive(fp16=True)

    if args.weight.endswith('.ckpt'):
        param_dict = load_checkpoint(args.weight)
        param_dict_new = {}
        for key, value in param_dict.items():
            if key.startswith('moments.'):
                # skip optimizer moment statistics; only network weights are needed
                continue
            elif key.startswith('network.'):
                # strip the 'network.' prefix added by the training wrapper cell
                param_dict_new[key[8:]] = value
            else:
                param_dict_new[key] = value
        load_param_into_net(net, param_dict_new)
        args.logger.info('INFO, ------------- load model success--------------')
    else:
        args.logger.info('ERROR, unsupported weight file:{}, please check weight in config.py'.format(args.weight))
        return 0
    net.set_train(False)
    return net


def topk(matrix, k, axis=1):
    '''Return the k largest values along `axis`, sorted descending, with their indices.'''
    if axis == 0:
        row_index = np.arange(matrix.shape[1 - axis])
        # argpartition selects the k largest entries per column (unordered) ...
        topk_index = np.argpartition(-matrix, k, axis=axis)[0:k, :]
        topk_data = matrix[topk_index, row_index]
        # ... then argsort orders just those k candidates
        topk_index_sort = np.argsort(-topk_data, axis=axis)
        topk_data_sort = topk_data[topk_index_sort, row_index]
        topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
    else:
        column_index = np.arange(matrix.shape[1 - axis])[:, None]
        topk_index = np.argpartition(-matrix, k, axis=axis)[:, 0:k]
        topk_data = matrix[column_index, topk_index]
        topk_index_sort = np.argsort(-topk_data, axis=axis)
        topk_data_sort = topk_data[column_index, topk_index_sort]
        topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
    return topk_data_sort, topk_index_sort
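

# For example, topk(np.array([[0.1, 0.9, 0.5]]), k=2) returns
# (array([[0.9, 0.5]]), array([[1, 2]])): the two largest values in each row,
# in descending order, together with their column indices.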


def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
    '''cal_topk'''
    args.logger.info('start idx:{} subprocess...'.format(idx))
    correct = np.array([0] * 2)
    tot = np.array([0])

    zj, jk_all = zj2jk_pairs[idx]
    zj_embedding = test_embedding_tot[zj]
    jk_all_embedding = np.concatenate([np.expand_dims(test_embedding_tot[jk], axis=0) for jk in jk_all], axis=0)
    args.logger.info('INFO, calculate top1 acc index:{}, zj_embedding shape:{}'.format(idx, zj_embedding.shape))
    args.logger.info('INFO, calculate top1 acc index:{}, jk_all_embedding shape:{}'.format(idx, jk_all_embedding.shape))

    test_time = time.time()
    # similarity of this zj (and all its jk probes) against every distractor embedding
    mm = np.matmul(np.expand_dims(zj_embedding, axis=0), dis_embedding_tot)
    top100_jk2zj = np.squeeze(topk(mm, 100)[0], axis=0)
    top100_zj2jk = topk(np.matmul(jk_all_embedding, dis_embedding_tot), 100)[0]
    test_time_used = time.time() - test_time
    args.logger.info('INFO, calculate top1 acc index:{}, np.matmul().top(100) time used:{:.2f}s'.format(
        idx, test_time_used))
    tot[0] = len(jk_all)
    for i, jk in enumerate(jk_all):
        jk_embedding = test_embedding_tot[jk]
        similarity = np.dot(jk_embedding, zj_embedding)
        # index 0 holds the highest distractor similarity, so this is a top-1 test
        if similarity > top100_jk2zj[0]:
            correct[0] += 1
        if similarity > top100_zj2jk[i, 0]:
            correct[1] += 1
    return correct, tot
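

# cal_topk returns (correct, tot) for one zj: correct[0] counts jk probes whose
# true-match similarity beats the zj embedding's best distractor score, correct[1]
# those that beat their own jk embedding's best distractor score; main() turns
# these counts into the zj2jk/jk2zj accuracies.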


def l2normalize(features):
    '''Row-wise L2 normalization with the norm clamped away from zero.'''
    epsilon = 1e-12
    l2norm = np.sum(np.abs(features) ** 2, axis=1, keepdims=True) ** (1. / 2)
    # clamp near-zero norms to +/-epsilon to avoid division by zero
    l2norm[np.logical_and(l2norm < 0, l2norm > -epsilon)] = -epsilon
    l2norm[np.logical_and(l2norm >= 0, l2norm < epsilon)] = epsilon
    return features / l2norm
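

# For example, l2normalize(np.array([[3.0, 4.0]])) returns array([[0.6, 0.8]]).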


def main(args):
    if not os.path.exists(args.test_dir):
        args.logger.info('ERROR, test_dir does not exist, please set test_dir in config.py.')
        return 0
    all_start_time = time.time()

    net = get_model(args)
    compile_time_used = time.time() - all_start_time
    args.logger.info('INFO, graph compile finished, time used:{:.2f}s, start calculate img embedding'.
                     format(compile_time_used))

    img_transforms = transforms.Compose([vision.ToTensor(), vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
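    # ToTensor converts HWC uint8 images to CHW float32 in [0, 1]; Normalize with
    # mean=std=0.5 then maps pixel values to roughly [-1, 1]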

    # for test images
    args.logger.info('INFO, start step1, calculate test img embedding, weight file = {}'.format(args.weight))
    step1_start_time = time.time()

    ds, img_tot, all_labels = get_dataloader(args.test_img_predix, args.test_img_list,
                                             args.test_batch_size, img_transforms)
    args.logger.info('INFO, dataset total test img:{}, total test batch:{}'.format(img_tot, ds.get_dataset_size()))
    test_embedding_tot_np = np.zeros((img_tot, args.emb_size))
    test_img_labels = all_labels
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
    for i, data in enumerate(data_loader):
        img, idxs = data["image"], data["index"]
        out = net(Tensor(img)).asnumpy().astype(np.float32)
        embeddings = l2normalize(out)
        # scatter each embedding back to its original dataset position
        for batch in range(embeddings.shape[0]):
            test_embedding_tot_np[idxs[batch]] = embeddings[batch]
    try:
        check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
    except ValueError:
        return 0
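
    # key each embedding row by its full "path label" line for pair lookup in step3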
    test_embedding_tot = {}
    for idx, label in enumerate(test_img_labels):
        test_embedding_tot[label] = test_embedding_tot_np[idx]

    step2_start_time = time.time()
    step1_time_used = step2_start_time - step1_start_time
    args.logger.info('INFO, step1 finished, time used:{:.2f}s, start step2, calculate dis img embedding'.
                     format(step1_time_used))

    # for dis images
    ds_dis, img_tot, _ = get_dataloader(args.dis_img_predix, args.dis_img_list, args.dis_batch_size, img_transforms)
    dis_embedding_tot_np = np.zeros((img_tot, args.emb_size))
    total_batch = ds_dis.get_dataset_size()
    args.logger.info('INFO, dataloader total dis img:{}, total dis batch:{}'.format(img_tot, total_batch))
    start_time = time.time()
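    # per-rank offset arithmetic, apparently kept from a multi-device variant: with
    # world_size=1 this reduces to start_idx=0; otherwise the last delta_num ranks
    # write one image fewer into their slice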
    img_per_gpu = int(math.ceil(1.0 * img_tot / args.world_size))
    delta_num = img_per_gpu * args.world_size - img_tot
    start_idx = img_per_gpu * args.local_rank - max(0, args.local_rank - (args.world_size - delta_num))
    data_loader = ds_dis.create_dict_iterator(output_numpy=True, num_epochs=1)
    for idx, data in enumerate(data_loader):
        img = data["image"]
        out = net(Tensor(img)).asnumpy().astype(np.float32)
        embeddings = l2normalize(out)
        dis_embedding_tot_np[start_idx:(start_idx + embeddings.shape[0])] = embeddings
        start_idx += embeddings.shape[0]
        if args.local_rank % 8 == 0 and idx % args.log_interval == 0 and idx > 0:
            speed = 1.0 * (args.dis_batch_size * args.log_interval * args.world_size) / (time.time() - start_time)
            time_left = (total_batch - idx - 1) * args.dis_batch_size * args.world_size / speed
            args.logger.info('INFO, processed [{}/{}], speed: {:.2f} img/s, left:{:.2f}s'.
                             format(idx, total_batch, speed, time_left))
            start_time = time.time()
    try:
        check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
    except ValueError:
        return 0

    step3_start_time = time.time()
    step2_time_used = step3_start_time - step2_start_time
    args.logger.info('INFO, step2 finished, time used:{:.2f}s, start step3, calculate top1 acc'.format(step2_time_used))

    # release references so device memory can be reclaimed before step3
    img = None
    net = None

    dis_embedding_tot_np = np.transpose(dis_embedding_tot_np, (1, 0))
    args.logger.info('INFO, calculate top1 acc dis_embedding_tot_np shape:{}'.format(dis_embedding_tot_np.shape))
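    # dis_embedding_tot_np is now (emb_size, num_dis), so the matmuls in cal_topk
    # yield (m, num_dis) similarity matrices directly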

    # find best match
    assert len(args.test_img_list) % 2 == 0
    task_num = len(args.test_img_list) // 2
    correct = np.array([0] * (2 * task_num))
    tot = np.array([0] * task_num)

    for i in range(task_num):
        jk_list = args.test_img_list[2 * i]
        zj_list = args.test_img_list[2 * i + 1]
        zj2jk_pairs = sorted(generate_test_pair(jk_list, zj_list))
        sampler = DistributedSampler(zj2jk_pairs)
        args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
        for idx in sampler:
            out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
            correct[2 * i] += out1[0]
            correct[2 * i + 1] += out1[1]
            tot[i] += out2[0]

    args.logger.info('local_rank={},tot={},correct={}'.format(args.local_rank, tot, correct))

    step3_time_used = time.time() - step3_start_time
    args.logger.info('INFO, step3 finished, time used:{:.2f}s'.format(step3_time_used))
    args.logger.info('weight:{}'.format(args.weight))

    for i in range(task_num):
        test_set_name = 'test_dataset'
        zj2jk_acc = correct[2 * i] / tot[i]
        jk2zj_acc = correct[2 * i + 1] / tot[i]
        avg_acc = (zj2jk_acc + jk2zj_acc) / 2
        results = '[{}]: zj2jk={:.4f}, jk2zj={:.4f}, avg={:.4f}'.format(test_set_name, zj2jk_acc, jk2zj_acc, avg_acc)
        args.logger.info(results)
    args.logger.info('INFO, tot time used: {:.2f}s'.format(time.time() - all_start_time))
    return 0


if __name__ == '__main__':
    arg = config_inference
    arg.test_img_predix = [arg.test_dir, arg.test_dir]

    arg.test_img_list = [os.path.join(arg.test_dir, 'lists/jk_list.txt'),
                         os.path.join(arg.test_dir, 'lists/zj_list.txt')]
    arg.dis_img_predix = [arg.test_dir,]
    arg.dis_img_list = [os.path.join(arg.test_dir, 'lists/dis_list.txt'),]

    log_path = os.path.join(arg.ckpt_path, 'logs')
    arg.logger = get_logger(log_path, arg.local_rank)

    arg.logger.info('Config\n\n%s\n' % pformat(arg))

    main(arg)