@@ -19,6 +19,7 @@ Retriever Evaluation.
 
 import time
 import json
+from multiprocessing import Pool
 
 import numpy as np
 from mindspore import Tensor
@@ -69,16 +70,20 @@ def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
     return val, true_count, topk_titles
 
 
-def evaluation():
+def evaluation(d_id):
     """evaluation"""
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target='Ascend',
+                        device_id=d_id,
+                        save_graphs=False)
     print('********************** loading corpus ********************** ')
     s_lc = time.time()
     data_generator = DataGen(config)
-    queries = read_query(config)
+    queries = read_query(config, d_id)
     print("loading corpus time (h):", (time.time() - s_lc) / 3600)
     print('********************** loading model ********************** ')
-    s_lm = time.time()
 
+    s_lm = time.time()
     model_onehop_bert = ModelOneHop()
     param_dict = load_checkpoint(config.onehop_bert_path)
     load_param_into_net(model_onehop_bert, param_dict)
@@ -90,10 +95,10 @@ def evaluation():
 
     print("loading model time (h):", (time.time() - s_lm) / 3600)
     print('********************** evaluation ********************** ')
-    s_tr = time.time()
+
     f_dev = open(config.dev_path, 'rb')
     dev_data = json.load(f_dev)
     f_dev.close()
     q_gold = {}
     q_2id = {}
     for onedata in dev_data:
@@ -101,10 +106,10 @@ def evaluation():
         q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
         q_2id[onedata["question"]] = onedata['_id']
     val, true_count, count, step = 0, 0, 0, 0
-    batch_queries = split_queries(config, queries)[:-1]
+    batch_queries = split_queries(config, queries)
     output_path = []
     for _, batch in enumerate(batch_queries):
-        print("###step###: ", step)
+        print("###step###: ", str(step) + "_" + str(d_id))
         query = batch[0]
         temp_dict = {}
         temp_dict['q_id'] = q_2id[query]
@@ -158,23 +163,36 @@ def evaluation():
         val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count)
         temp_dict['topk_titles'] = topk_titles
         output_path.append(temp_dict)
-        count += 1
-        print("val:", val)
-        print("count:", count)
-        print("true count:", true_count)
-        if count:
-            print("PEM:", val / count)
-        if true_count:
-            print("true top8 PEM:", val / true_count)
         step += 1
-    save_json(output_path, config.save_path, config.save_name)
-    print("evaluation time (h):", (time.time() - s_tr) / 3600)
+        count += 1
+    return {'val': val, 'count': count, 'true_count': true_count, 'path': output_path}
 
 
 if __name__ == "__main__":
+    t_s = time.time()
     config = ThinkRetrieverConfig()
-    context.set_context(mode=context.GRAPH_MODE,
-                        device_target='Ascend',
-                        device_id=config.device_id,
-                        save_graphs=False)
-    evaluation()
+    pool = Pool(processes=config.device_num)
+    results = []
+    for device_id in range(config.device_num):
+        results.append(pool.apply_async(evaluation, (device_id,)))
+
+    print("Waiting for all subprocess done...")
+
+    pool.close()
+    pool.join()
+
+    val_all, true_count_all, count_all = 0, 0, 0
+    output_path_all = []
+    for res in results:
+        output = res.get()
+        val_all += output['val']
+        count_all += output['count']
+        true_count_all += output['true_count']
+        output_path_all += output['path']
+    print("val:", val_all)
+    print("count:", count_all)
+    print("true count:", true_count_all)
+    print("PEM:", val_all / count_all)
+    print("true top8 PEM:", val_all / true_count_all)
+    save_json(output_path_all, config.save_path, config.save_name)
+    print("evaluation time (h):", (time.time() - t_s) / 3600)
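
Note: below is a minimal, self-contained sketch of the fan-out/aggregate pattern this patch introduces, useful for sanity-checking the aggregation math in isolation. The evaluation() stub and the device_num value are hypothetical stand-ins; the real function runs the retriever on one Ascend device and returns its partial statistics.

    # Sketch of the Pool.apply_async fan-out and fan-in used in the patch.
    from multiprocessing import Pool

    def evaluation(d_id):
        # Hypothetical stand-in: pretend this device answered 10 queries,
        # 7 with an exact path match, 9 with the gold path in the top-8.
        return {'val': 7, 'count': 10, 'true_count': 9, 'path': []}

    if __name__ == "__main__":
        device_num = 8  # assumed value of config.device_num
        pool = Pool(processes=device_num)
        results = [pool.apply_async(evaluation, (d,)) for d in range(device_num)]
        pool.close()  # no further tasks will be submitted
        pool.join()   # block until every worker process exits
        outputs = [res.get() for res in results]  # re-raises any worker error
        val_all = sum(o['val'] for o in outputs)
        count_all = sum(o['count'] for o in outputs)
        # Micro-averaged PEM across devices, matching the patch's final print.
        print("PEM:", val_all / count_all)  # 0.7

Each worker returns its raw counts rather than a ratio, so the parent can micro-average (sum of val over sum of count) instead of averaging per-device PEMs, which would be skewed if devices receive unequal query shards.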