Add TPRR 8p (8-device parallel evaluation) version

pull/13838/head
huenrui 5 years ago committed by zhanke
parent b1043bcf55
commit 676c219bc4

@ -19,6 +19,7 @@ Retriever Evaluation.
import time import time
import json import json
from multiprocessing import Pool
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor
@ -69,16 +70,20 @@ def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
return val, true_count, topk_titles return val, true_count, topk_titles
def evaluation(): def evaluation(d_id):
"""evaluation""" """evaluation"""
context.set_context(mode=context.GRAPH_MODE,
device_target='Ascend',
device_id=d_id,
save_graphs=False)
print('********************** loading corpus ********************** ') print('********************** loading corpus ********************** ')
s_lc = time.time() s_lc = time.time()
data_generator = DataGen(config) data_generator = DataGen(config)
queries = read_query(config) queries = read_query(config, d_id)
print("loading corpus time (h):", (time.time() - s_lc) / 3600) print("loading corpus time (h):", (time.time() - s_lc) / 3600)
print('********************** loading model ********************** ') print('********************** loading model ********************** ')
s_lm = time.time()
s_lm = time.time()
model_onehop_bert = ModelOneHop() model_onehop_bert = ModelOneHop()
param_dict = load_checkpoint(config.onehop_bert_path) param_dict = load_checkpoint(config.onehop_bert_path)
load_param_into_net(model_onehop_bert, param_dict) load_param_into_net(model_onehop_bert, param_dict)
@ -90,10 +95,10 @@ def evaluation():
print("loading model time (h):", (time.time() - s_lm) / 3600) print("loading model time (h):", (time.time() - s_lm) / 3600)
print('********************** evaluation ********************** ') print('********************** evaluation ********************** ')
s_tr = time.time()
f_dev = open(config.dev_path, 'rb') f_dev = open(config.dev_path, 'rb')
dev_data = json.load(f_dev) dev_data = json.load(f_dev)
f_dev.close()
q_gold = {} q_gold = {}
q_2id = {} q_2id = {}
for onedata in dev_data: for onedata in dev_data:
@ -101,10 +106,10 @@ def evaluation():
q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']] q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
q_2id[onedata["question"]] = onedata['_id'] q_2id[onedata["question"]] = onedata['_id']
val, true_count, count, step = 0, 0, 0, 0 val, true_count, count, step = 0, 0, 0, 0
batch_queries = split_queries(config, queries)[:-1] batch_queries = split_queries(config, queries)
output_path = [] output_path = []
for _, batch in enumerate(batch_queries): for _, batch in enumerate(batch_queries):
print("###step###: ", step) print("###step###: ", str(step) + "_" + str(d_id))
query = batch[0] query = batch[0]
temp_dict = {} temp_dict = {}
temp_dict['q_id'] = q_2id[query] temp_dict['q_id'] = q_2id[query]
@ -158,23 +163,36 @@ def evaluation():
val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count) val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count)
temp_dict['topk_titles'] = topk_titles temp_dict['topk_titles'] = topk_titles
output_path.append(temp_dict) output_path.append(temp_dict)
count += 1
print("val:", val)
print("count:", count)
print("true count:", true_count)
if count:
print("PEM:", val / count)
if true_count:
print("true top8 PEM:", val / true_count)
step += 1 step += 1
save_json(output_path, config.save_path, config.save_name) count += 1
print("evaluation time (h):", (time.time() - s_tr) / 3600) return {'val': val, 'count': count, 'true_count': true_count, 'path': output_path}
if __name__ == "__main__": if __name__ == "__main__":
t_s = time.time()
config = ThinkRetrieverConfig() config = ThinkRetrieverConfig()
context.set_context(mode=context.GRAPH_MODE, pool = Pool(processes=config.device_num)
device_target='Ascend', results = []
device_id=config.device_id, for device_id in range(config.device_num):
save_graphs=False) results.append(pool.apply_async(evaluation, (device_id,)))
evaluation()
print("Waiting for all subprocess done...")
pool.close()
pool.join()
val_all, true_count_all, count_all = 0, 0, 0
output_path_all = []
for res in results:
output = res.get()
val_all += output['val']
count_all += output['count']
true_count_all += output['true_count']
output_path_all += output['path']
print("val:", val_all)
print("count:", count_all)
print("true count:", true_count_all)
print("PEM:", val_all / count_all)
print("true top8 PEM:", val_all / true_count_all)
save_json(output_path_all, config.save_path, config.save_name)
print("evaluation time (h):", (time.time() - t_s) / 3600)

@ -31,7 +31,7 @@ def ThinkRetrieverConfig():
parser.add_argument("--topk", type=int, default=8, help="top num") parser.add_argument("--topk", type=int, default=8, help="top num")
parser.add_argument("--onehop_num", type=int, default=8, help="onehop num") parser.add_argument("--onehop_num", type=int, default=8, help="onehop num")
parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--device_num", type=int, default=8, help="device num")
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output') parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
parser.add_argument("--save_path", type=str, default='../', help='path of output') parser.add_argument("--save_path", type=str, default='../', help='path of output')
parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path") parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
@ -43,4 +43,5 @@ def ThinkRetrieverConfig():
parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path") parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path") parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path") parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
parser.add_argument("--q_path", type=str, default="../queries", help="queries data path")
return parser.parse_args() return parser.parse_args()

@ -53,6 +53,9 @@ class DataGen:
data_db = pkl.load(f_wiki, encoding="gbk") data_db = pkl.load(f_wiki, encoding="gbk")
dev_data = json.load(f_train) dev_data = json.load(f_train)
q_doc_text = pkl.load(f_doc, encoding='gbk') q_doc_text = pkl.load(f_doc, encoding='gbk')
f_wiki.close()
f_train.close()
f_doc.close()
return data_db, dev_data, q_doc_text return data_db, dev_data, q_doc_text
def process_data(self): def process_data(self):

@ -28,13 +28,10 @@ def normalize(text):
return text[0].capitalize() + text[1:] return text[0].capitalize() + text[1:]
def read_query(config): def read_query(config, device_id):
"""get query data""" """get query data"""
with open(config.dev_data_path, 'rb') as f: with open(config.q_path + str(device_id), 'rb') as f:
temp_dic = pkl.load(f, encoding='gbk') queries = pkl.load(f, encoding='gbk')
queries = []
for item in temp_dic:
queries.append(temp_dic[item]["query"])
return queries return queries

Loading…
Cancel
Save