|
|
@ -109,7 +109,7 @@ def generate_test_pair(jk_list, zj_list):
|
|
|
|
zj2jk_pairs.append([zj_file, jk_file_list])
|
|
|
|
zj2jk_pairs.append([zj_file, jk_file_list])
|
|
|
|
return zj2jk_pairs
|
|
|
|
return zj2jk_pairs
|
|
|
|
|
|
|
|
|
|
|
|
def check_minmax(data, min_value=0.99, max_value=1.01):
|
|
|
|
def check_minmax(args, data, min_value=0.99, max_value=1.01):
|
|
|
|
min_data = data.min()
|
|
|
|
min_data = data.min()
|
|
|
|
max_data = data.max()
|
|
|
|
max_data = data.max()
|
|
|
|
if np.isnan(min_data) or np.isnan(max_data):
|
|
|
|
if np.isnan(min_data) or np.isnan(max_data):
|
|
|
@ -162,7 +162,7 @@ def topk(matrix, k, axis=1):
|
|
|
|
topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
|
|
|
|
topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
|
|
|
|
return topk_data_sort, topk_index_sort
|
|
|
|
return topk_data_sort, topk_index_sort
|
|
|
|
|
|
|
|
|
|
|
|
def cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
|
|
|
|
def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
|
|
|
|
'''cal_topk'''
|
|
|
|
'''cal_topk'''
|
|
|
|
args.logger.info('start idx:{} subprocess...'.format(idx))
|
|
|
|
args.logger.info('start idx:{} subprocess...'.format(idx))
|
|
|
|
correct = np.array([0] * 2)
|
|
|
|
correct = np.array([0] * 2)
|
|
|
@ -230,7 +230,7 @@ def main(args):
|
|
|
|
for batch in range(embeddings.shape[0]):
|
|
|
|
for batch in range(embeddings.shape[0]):
|
|
|
|
test_embedding_tot_np[idxs[batch]] = embeddings[batch]
|
|
|
|
test_embedding_tot_np[idxs[batch]] = embeddings[batch]
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
check_minmax(np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
|
|
|
|
check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
|
|
|
|
except ValueError:
|
|
|
|
except ValueError:
|
|
|
|
return 0
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
|
@ -266,7 +266,7 @@ def main(args):
|
|
|
|
format(idx, total_batch, speed, time_left))
|
|
|
|
format(idx, total_batch, speed, time_left))
|
|
|
|
start_time = time.time()
|
|
|
|
start_time = time.time()
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
check_minmax(np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
|
|
|
|
check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
|
|
|
|
except ValueError:
|
|
|
|
except ValueError:
|
|
|
|
return 0
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
|
@ -295,7 +295,7 @@ def main(args):
|
|
|
|
sampler = DistributedSampler(zj2jk_pairs)
|
|
|
|
sampler = DistributedSampler(zj2jk_pairs)
|
|
|
|
args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
|
|
|
|
args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
|
|
|
|
for idx in sampler:
|
|
|
|
for idx in sampler:
|
|
|
|
out1, out2 = cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
|
|
|
|
out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
|
|
|
|
correct[2 * i] += out1[0]
|
|
|
|
correct[2 * i] += out1[0]
|
|
|
|
correct[2 * i + 1] += out1[1]
|
|
|
|
correct[2 * i + 1] += out1[1]
|
|
|
|
tot[i] += out2[0]
|
|
|
|
tot[i] += out2[0]
|
|
|
|