!10270 Fix FaceRecognition net eval fail

From: @zhanghuiyao
Reviewed-by: @linqingke,@oacjiewen
Signed-off-by: @linqingke
pull/10270/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5a1b09c852

@ -109,7 +109,7 @@ def generate_test_pair(jk_list, zj_list):
zj2jk_pairs.append([zj_file, jk_file_list]) zj2jk_pairs.append([zj_file, jk_file_list])
return zj2jk_pairs return zj2jk_pairs
def check_minmax(data, min_value=0.99, max_value=1.01): def check_minmax(args, data, min_value=0.99, max_value=1.01):
min_data = data.min() min_data = data.min()
max_data = data.max() max_data = data.max()
if np.isnan(min_data) or np.isnan(max_data): if np.isnan(min_data) or np.isnan(max_data):
@ -162,7 +162,7 @@ def topk(matrix, k, axis=1):
topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort] topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
return topk_data_sort, topk_index_sort return topk_data_sort, topk_index_sort
def cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot): def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
'''cal_topk''' '''cal_topk'''
args.logger.info('start idx:{} subprocess...'.format(idx)) args.logger.info('start idx:{} subprocess...'.format(idx))
correct = np.array([0] * 2) correct = np.array([0] * 2)
@ -230,7 +230,7 @@ def main(args):
for batch in range(embeddings.shape[0]): for batch in range(embeddings.shape[0]):
test_embedding_tot_np[idxs[batch]] = embeddings[batch] test_embedding_tot_np[idxs[batch]] = embeddings[batch]
try: try:
check_minmax(np.linalg.norm(test_embedding_tot_np, ord=2, axis=1)) check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
except ValueError: except ValueError:
return 0 return 0
@ -266,7 +266,7 @@ def main(args):
format(idx, total_batch, speed, time_left)) format(idx, total_batch, speed, time_left))
start_time = time.time() start_time = time.time()
try: try:
check_minmax(np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1)) check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
except ValueError: except ValueError:
return 0 return 0
@ -295,7 +295,7 @@ def main(args):
sampler = DistributedSampler(zj2jk_pairs) sampler = DistributedSampler(zj2jk_pairs)
args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler))) args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
for idx in sampler: for idx in sampler:
out1, out2 = cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np) out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
correct[2 * i] += out1[0] correct[2 * i] += out1[0]
correct[2 * i + 1] += out1[1] correct[2 * i + 1] += out1[1]
tot[i] += out2[0] tot[i] += out2[0]

Loading…
Cancel
Save