!13375 add reranker and reader into modelzoo/research/tprr
From: @huenrui Reviewed-by: @oacjiewen,@guoqi1024 Signed-off-by:pull/13375/MERGE
commit
5313abfb66
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""main file"""
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
|
from src.rerank_and_reader_utils import get_parse, cal_reranker_metrics, select_reader_dev_data
|
||||||
|
from src.reranker_eval import rerank
|
||||||
|
from src.reader_eval import read
|
||||||
|
from src.hotpot_evaluate_v1 import hotpotqa_eval
|
||||||
|
from src.build_reranker_data import get_rerank_data
|
||||||
|
|
||||||
|
|
||||||
|
def rerank_and_retriever_eval():
|
||||||
|
"""main function"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
parser = get_parse()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.get_reranker_data:
|
||||||
|
get_rerank_data(args)
|
||||||
|
|
||||||
|
if args.run_reranker:
|
||||||
|
rerank(args)
|
||||||
|
|
||||||
|
if args.cal_reranker_metrics:
|
||||||
|
total_top1_pem, _, _ = \
|
||||||
|
cal_reranker_metrics(dev_gold_file=args.dev_gold_file, rerank_result_file=args.rerank_result_file)
|
||||||
|
print(f"total top1 pem: {total_top1_pem}")
|
||||||
|
|
||||||
|
if args.select_reader_data:
|
||||||
|
select_reader_dev_data(args)
|
||||||
|
|
||||||
|
if args.run_reader:
|
||||||
|
read(args)
|
||||||
|
|
||||||
|
if args.cal_reader_metrics:
|
||||||
|
metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_file)
|
||||||
|
for k in metrics:
|
||||||
|
print(f"{k}: {metrics[k]}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
rerank_and_retriever_eval()
|
@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# eval script
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
|
||||||
|
cp ../*.py ./eval
|
||||||
|
cp *.sh ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
|
cd ./eval || exit
|
||||||
|
env > env.log
|
||||||
|
echo "start evaluation"
|
||||||
|
|
||||||
|
python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics > log_reranker_and_reader.txt 2>&1 &
|
||||||
|
|
||||||
|
cd ..
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,153 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""hotpotqa evaluate script"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
from collections import Counter
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_answer(s):
|
||||||
|
"""normalize answer"""
|
||||||
|
def remove_articles(text):
|
||||||
|
"""remove articles"""
|
||||||
|
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||||
|
|
||||||
|
def white_space_fix(text):
|
||||||
|
"""fix whitespace"""
|
||||||
|
return ' '.join(text.split())
|
||||||
|
|
||||||
|
def remove_punc(text):
|
||||||
|
"""remove punctuation from text"""
|
||||||
|
exclude = set(string.punctuation)
|
||||||
|
return ''.join(ch for ch in text if ch not in exclude)
|
||||||
|
|
||||||
|
def lower(text):
|
||||||
|
"""lower text"""
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||||
|
|
||||||
|
|
||||||
|
def f1_score(prediction, ground_truth):
|
||||||
|
"""calculate f1 score"""
|
||||||
|
normalized_prediction = normalize_answer(prediction)
|
||||||
|
normalized_ground_truth = normalize_answer(ground_truth)
|
||||||
|
|
||||||
|
ZERO_METRIC = (0, 0, 0)
|
||||||
|
|
||||||
|
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
|
||||||
|
return ZERO_METRIC
|
||||||
|
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
|
||||||
|
return ZERO_METRIC
|
||||||
|
|
||||||
|
prediction_tokens = normalized_prediction.split()
|
||||||
|
ground_truth_tokens = normalized_ground_truth.split()
|
||||||
|
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||||
|
num_same = sum(common.values())
|
||||||
|
if num_same == 0:
|
||||||
|
return ZERO_METRIC
|
||||||
|
precision = 1.0 * num_same / len(prediction_tokens)
|
||||||
|
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||||
|
f1 = (2 * precision * recall) / (precision + recall)
|
||||||
|
return f1, precision, recall
|
||||||
|
|
||||||
|
|
||||||
|
def exact_match_score(prediction, ground_truth):
|
||||||
|
"""calculate exact match score"""
|
||||||
|
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||||
|
|
||||||
|
|
||||||
|
def update_answer(metrics, prediction, gold):
|
||||||
|
"""update answer"""
|
||||||
|
em = exact_match_score(prediction, gold)
|
||||||
|
f1, prec, recall = f1_score(prediction, gold)
|
||||||
|
metrics['em'] += float(em)
|
||||||
|
metrics['f1'] += f1
|
||||||
|
metrics['prec'] += prec
|
||||||
|
metrics['recall'] += recall
|
||||||
|
return em, prec, recall
|
||||||
|
|
||||||
|
|
||||||
|
def update_sp(metrics, prediction, gold):
|
||||||
|
"""update supporting sentences"""
|
||||||
|
cur_sp_pred = set(map(tuple, prediction))
|
||||||
|
gold_sp_pred = set(map(tuple, gold))
|
||||||
|
tp, fp, fn = 0, 0, 0
|
||||||
|
for e in cur_sp_pred:
|
||||||
|
if e in gold_sp_pred:
|
||||||
|
tp += 1
|
||||||
|
else:
|
||||||
|
fp += 1
|
||||||
|
for e in gold_sp_pred:
|
||||||
|
if e not in cur_sp_pred:
|
||||||
|
fn += 1
|
||||||
|
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
|
||||||
|
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
|
||||||
|
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
|
||||||
|
em = 1.0 if fp + fn == 0 else 0.0
|
||||||
|
metrics['sp_em'] += em
|
||||||
|
metrics['sp_f1'] += f1
|
||||||
|
metrics['sp_prec'] += prec
|
||||||
|
metrics['sp_recall'] += recall
|
||||||
|
return em, prec, recall
|
||||||
|
|
||||||
|
|
||||||
|
def hotpotqa_eval(prediction_file, gold_file):
|
||||||
|
"""hotpotqa evaluate function"""
|
||||||
|
with open(prediction_file) as f:
|
||||||
|
prediction = json.load(f)
|
||||||
|
with open(gold_file) as f:
|
||||||
|
gold = json.load(f)
|
||||||
|
|
||||||
|
metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
|
||||||
|
'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
|
||||||
|
'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
|
||||||
|
for dp in gold:
|
||||||
|
cur_id = dp['_id']
|
||||||
|
can_eval_joint = True
|
||||||
|
if cur_id not in prediction['answer']:
|
||||||
|
print('missing answer {}'.format(cur_id))
|
||||||
|
can_eval_joint = False
|
||||||
|
else:
|
||||||
|
em, prec, recall = update_answer(
|
||||||
|
metrics, prediction['answer'][cur_id], dp['answer'])
|
||||||
|
if cur_id not in prediction['sp']:
|
||||||
|
print('missing sp fact {}'.format(cur_id))
|
||||||
|
can_eval_joint = False
|
||||||
|
else:
|
||||||
|
sp_em, sp_prec, sp_recall = update_sp(
|
||||||
|
metrics, prediction['sp'][cur_id], dp['supporting_facts'])
|
||||||
|
|
||||||
|
if can_eval_joint:
|
||||||
|
joint_prec = prec * sp_prec
|
||||||
|
joint_recall = recall * sp_recall
|
||||||
|
if joint_prec + joint_recall > 0:
|
||||||
|
joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
|
||||||
|
else:
|
||||||
|
joint_f1 = 0.
|
||||||
|
joint_em = em * sp_em
|
||||||
|
|
||||||
|
metrics['joint_em'] += joint_em
|
||||||
|
metrics['joint_f1'] += joint_f1
|
||||||
|
metrics['joint_prec'] += joint_prec
|
||||||
|
metrics['joint_recall'] += joint_recall
|
||||||
|
|
||||||
|
num = len(gold)
|
||||||
|
for k in metrics:
|
||||||
|
metrics[k] /= num
|
||||||
|
|
||||||
|
return metrics
|
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Reader model"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.ops import BatchMatMul
|
||||||
|
from mindspore import ops
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
from src.reader_albert_xxlarge import Reader_Albert
|
||||||
|
from src.reader_downstream import Reader_Downstream
|
||||||
|
|
||||||
|
|
||||||
|
dst_type = mstype.float16
|
||||||
|
dst_type2 = mstype.float32
|
||||||
|
|
||||||
|
|
||||||
|
class Reader(nn.Cell):
|
||||||
|
"""Reader model"""
|
||||||
|
def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
|
||||||
|
"""init function"""
|
||||||
|
super(Reader, self).__init__(auto_prefix=False)
|
||||||
|
|
||||||
|
self.encoder = Reader_Albert(batch_size)
|
||||||
|
param_dict = load_checkpoint(encoder_ck_file)
|
||||||
|
not_load_params = load_param_into_net(self.encoder, param_dict)
|
||||||
|
print(f"not loaded: {not_load_params}")
|
||||||
|
|
||||||
|
self.downstream = Reader_Downstream()
|
||||||
|
param_dict = load_checkpoint(downstream_ck_file)
|
||||||
|
not_load_params = load_param_into_net(self.downstream, param_dict)
|
||||||
|
print(f"not loaded: {not_load_params}")
|
||||||
|
|
||||||
|
self.bmm = BatchMatMul()
|
||||||
|
|
||||||
|
def construct(self, input_ids, attn_mask, token_type_ids,
|
||||||
|
context_mask, square_mask, packing_mask, cache_mask,
|
||||||
|
para_start_mapping, sent_end_mapping):
|
||||||
|
"""construct function"""
|
||||||
|
state = self.encoder(attn_mask, input_ids, token_type_ids)
|
||||||
|
|
||||||
|
para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, 2, D]
|
||||||
|
sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, max_sent, D]
|
||||||
|
|
||||||
|
q_type, start, end, para_logit, sent_logit = self.downstream(ops.Cast()(para_state, dst_type2),
|
||||||
|
ops.Cast()(sent_state, dst_type2),
|
||||||
|
state,
|
||||||
|
context_mask)
|
||||||
|
|
||||||
|
outer = start[:, :, None] + end[:, None]
|
||||||
|
|
||||||
|
outer_mask = cache_mask
|
||||||
|
outer_mask = square_mask * outer_mask[None]
|
||||||
|
outer = outer - 1e30 * (1 - outer_mask)
|
||||||
|
outer = outer - 1e30 * packing_mask[:, :, None]
|
||||||
|
max_row = ops.ReduceMax()(outer, 2)
|
||||||
|
y1 = ops.Argmax()(max_row)
|
||||||
|
max_col = ops.ReduceMax()(outer, 1)
|
||||||
|
y2 = ops.Argmax()(max_col)
|
||||||
|
|
||||||
|
return start, end, q_type, para_logit, sent_logit, y1, y2
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,213 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""downstream Model for reader"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import nn, ops
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
|
||||||
|
|
||||||
|
dst_type = mstype.float16
|
||||||
|
dst_type2 = mstype.float32
|
||||||
|
|
||||||
|
|
||||||
|
class Module15(nn.Cell):
|
||||||
|
"""module of reader downstream"""
|
||||||
|
def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
|
||||||
|
"""init function"""
|
||||||
|
super(Module15, self).__init__()
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
|
||||||
|
name=None)
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
|
||||||
|
self.relu_2 = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
|
||||||
|
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
|
||||||
|
opt_relu_2 = self.relu_2(opt_add_1)
|
||||||
|
return opt_relu_2
|
||||||
|
|
||||||
|
|
||||||
|
class NormModule(nn.Cell):
|
||||||
|
"""Normalization module of reader downstream"""
|
||||||
|
def __init__(self, mul_8_w_shape, add_9_bias_shape):
|
||||||
|
"""init function"""
|
||||||
|
super(NormModule, self).__init__()
|
||||||
|
self.reducemean_0 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.sub_1 = P.Sub()
|
||||||
|
self.sub_2 = P.Sub()
|
||||||
|
self.pow_3 = P.Pow()
|
||||||
|
self.pow_3_input_weight = 2.0
|
||||||
|
self.reducemean_4 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.add_5 = P.Add()
|
||||||
|
self.add_5_bias = 9.999999960041972e-13
|
||||||
|
self.sqrt_6 = P.Sqrt()
|
||||||
|
self.div_7 = P.Div()
|
||||||
|
self.mul_8 = P.Mul()
|
||||||
|
self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, mul_8_w_shape).astype(np.float32)), name=None)
|
||||||
|
self.add_9 = P.Add()
|
||||||
|
self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, add_9_bias_shape).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_reducemean_0 = self.reducemean_0(x, -1)
|
||||||
|
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
|
||||||
|
opt_sub_2 = self.sub_2(x, opt_reducemean_0)
|
||||||
|
opt_pow_3 = self.pow_3(opt_sub_1, self.pow_3_input_weight)
|
||||||
|
opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1)
|
||||||
|
opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias)
|
||||||
|
opt_sqrt_6 = self.sqrt_6(opt_add_5)
|
||||||
|
opt_div_7 = self.div_7(opt_sub_2, opt_sqrt_6)
|
||||||
|
opt_mul_8 = self.mul_8(self.mul_8_w, opt_div_7)
|
||||||
|
opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias)
|
||||||
|
return opt_add_9
|
||||||
|
|
||||||
|
|
||||||
|
class Module16(nn.Cell):
|
||||||
|
"""module of reader downstream"""
|
||||||
|
def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
|
||||||
|
normmodule_0_add_9_bias_shape):
|
||||||
|
"""init function"""
|
||||||
|
super(Module16, self).__init__()
|
||||||
|
self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
|
||||||
|
add_1_bias_shape=module15_0_add_1_bias_shape)
|
||||||
|
self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
|
||||||
|
add_9_bias_shape=normmodule_0_add_9_bias_shape)
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
module15_0_opt = self.module15_0(x)
|
||||||
|
normmodule_0_opt = self.normmodule_0(module15_0_opt)
|
||||||
|
opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
|
||||||
|
return ops.Cast()(opt_matmul_0, dst_type2)
|
||||||
|
|
||||||
|
|
||||||
|
class Module17(nn.Cell):
|
||||||
|
"""module of reader downstream"""
|
||||||
|
def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
|
||||||
|
normmodule_0_add_9_bias_shape):
|
||||||
|
"""init function"""
|
||||||
|
super(Module17, self).__init__()
|
||||||
|
self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
|
||||||
|
add_1_bias_shape=module15_0_add_1_bias_shape)
|
||||||
|
self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
|
||||||
|
add_9_bias_shape=normmodule_0_add_9_bias_shape)
|
||||||
|
self.matmul_0 = nn.MatMul()
|
||||||
|
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None)
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
module15_0_opt = self.module15_0(x)
|
||||||
|
normmodule_0_opt = self.normmodule_0(module15_0_opt)
|
||||||
|
opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
|
||||||
|
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
|
||||||
|
return opt_add_1
|
||||||
|
|
||||||
|
|
||||||
|
class Module5(nn.Cell):
|
||||||
|
"""module of reader downstream"""
|
||||||
|
def __init__(self):
|
||||||
|
"""init function"""
|
||||||
|
super(Module5, self).__init__()
|
||||||
|
self.sub_0 = P.Sub()
|
||||||
|
self.sub_0_bias = 1.0
|
||||||
|
self.mul_1 = P.Mul()
|
||||||
|
self.mul_1_w = 1.0000000150474662e+30
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_sub_0 = self.sub_0(self.sub_0_bias, x)
|
||||||
|
opt_mul_1 = self.mul_1(opt_sub_0, self.mul_1_w)
|
||||||
|
return opt_mul_1
|
||||||
|
|
||||||
|
|
||||||
|
class Module10(nn.Cell):
|
||||||
|
"""module of reader downstream"""
|
||||||
|
def __init__(self):
|
||||||
|
"""init function"""
|
||||||
|
super(Module10, self).__init__()
|
||||||
|
self.squeeze_0 = P.Squeeze(2)
|
||||||
|
self.module5_0 = Module5()
|
||||||
|
self.sub_1 = P.Sub()
|
||||||
|
|
||||||
|
def construct(self, x, x0):
|
||||||
|
"""construct function"""
|
||||||
|
opt_squeeze_0 = self.squeeze_0(x)
|
||||||
|
module5_0_opt = self.module5_0(x0)
|
||||||
|
opt_sub_1 = self.sub_1(opt_squeeze_0, module5_0_opt)
|
||||||
|
return opt_sub_1
|
||||||
|
|
||||||
|
|
||||||
|
class Reader_Downstream(nn.Cell):
|
||||||
|
"""Downstream model for reader"""
|
||||||
|
def __init__(self):
|
||||||
|
"""init function"""
|
||||||
|
super(Reader_Downstream, self).__init__()
|
||||||
|
self.module16_0 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
|
||||||
|
module15_0_add_1_bias_shape=(8192,),
|
||||||
|
normmodule_0_mul_8_w_shape=(8192,),
|
||||||
|
normmodule_0_add_9_bias_shape=(8192,))
|
||||||
|
self.add_74 = P.Add()
|
||||||
|
self.add_74_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
|
||||||
|
self.module16_1 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
|
||||||
|
module15_0_add_1_bias_shape=(8192,),
|
||||||
|
normmodule_0_mul_8_w_shape=(8192,),
|
||||||
|
normmodule_0_add_9_bias_shape=(8192,))
|
||||||
|
self.add_75 = P.Add()
|
||||||
|
self.add_75_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
|
||||||
|
self.module17_0 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
|
||||||
|
module15_0_add_1_bias_shape=(4096,),
|
||||||
|
normmodule_0_mul_8_w_shape=(4096,),
|
||||||
|
normmodule_0_add_9_bias_shape=(4096,))
|
||||||
|
self.module10_0 = Module10()
|
||||||
|
self.module17_1 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
|
||||||
|
module15_0_add_1_bias_shape=(4096,),
|
||||||
|
normmodule_0_mul_8_w_shape=(4096,),
|
||||||
|
normmodule_0_add_9_bias_shape=(4096,))
|
||||||
|
self.module10_1 = Module10()
|
||||||
|
self.gather_6_input_weight = Tensor(np.array(0))
|
||||||
|
self.gather_6_axis = 1
|
||||||
|
self.gather_6 = P.Gather()
|
||||||
|
self.dense_13 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True)
|
||||||
|
self.relu_18 = nn.ReLU()
|
||||||
|
self.normmodule_0 = NormModule(mul_8_w_shape=(4096,), add_9_bias_shape=(4096,))
|
||||||
|
self.dense_73 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True)
|
||||||
|
|
||||||
|
def construct(self, x, x0, x1, x2):
|
||||||
|
"""construct function"""
|
||||||
|
module16_0_opt = self.module16_0(x)
|
||||||
|
opt_add_74 = self.add_74(module16_0_opt, self.add_74_bias)
|
||||||
|
module16_1_opt = self.module16_1(x0)
|
||||||
|
opt_add_75 = self.add_75(module16_1_opt, self.add_75_bias)
|
||||||
|
module17_0_opt = self.module17_0(x1)
|
||||||
|
opt_module10_0 = self.module10_0(module17_0_opt, x2)
|
||||||
|
module17_1_opt = self.module17_1(x1)
|
||||||
|
opt_module10_1 = self.module10_1(module17_1_opt, x2)
|
||||||
|
opt_gather_6_axis = self.gather_6_axis
|
||||||
|
opt_gather_6 = self.gather_6(x1, self.gather_6_input_weight, opt_gather_6_axis)
|
||||||
|
opt_dense_13 = self.dense_13(opt_gather_6)
|
||||||
|
opt_relu_18 = self.relu_18(opt_dense_13)
|
||||||
|
normmodule_0_opt = self.normmodule_0(opt_relu_18)
|
||||||
|
opt_dense_73 = self.dense_73(normmodule_0_opt)
|
||||||
|
return opt_dense_73, opt_module10_0, opt_module10_1, opt_add_74, opt_add_75
|
@ -0,0 +1,142 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""execute reader"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
import random
|
||||||
|
from time import time
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import AlbertTokenizer
|
||||||
|
|
||||||
|
from mindspore import Tensor, ops
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
|
||||||
|
from src.rerank_and_reader_data_generator import DataGenerator
|
||||||
|
from src.rerank_and_reader_utils import convert_to_tokens, make_wiki_id, DocDB
|
||||||
|
from src.reader import Reader
|
||||||
|
|
||||||
|
|
||||||
|
def read(args):
|
||||||
|
"""reader function"""
|
||||||
|
db_file = args.wiki_db_file
|
||||||
|
reader_feature_file = args.reader_feature_file
|
||||||
|
reader_example_file = args.reader_example_file
|
||||||
|
encoder_ck_file = args.reader_encoder_ck_file
|
||||||
|
downstream_ck_file = args.reader_downstream_ck_file
|
||||||
|
albert_model_path = args.albert_model_path
|
||||||
|
reader_result_file = args.reader_result_file
|
||||||
|
seed = args.seed
|
||||||
|
sp_threshold = args.sp_threshold
|
||||||
|
seq_len = args.seq_len
|
||||||
|
batch_size = args.reader_batch_size
|
||||||
|
para_limit = args.max_para_num
|
||||||
|
sent_limit = args.max_sent_num
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
t1 = time()
|
||||||
|
|
||||||
|
doc_db = DocDB(db_file)
|
||||||
|
|
||||||
|
generator = DataGenerator(feature_file_path=reader_feature_file,
|
||||||
|
example_file_path=reader_example_file,
|
||||||
|
batch_size=batch_size, seq_len=seq_len,
|
||||||
|
para_limit=para_limit, sent_limit=sent_limit,
|
||||||
|
task_type="reader")
|
||||||
|
example_dict = generator.example_dict
|
||||||
|
feature_dict = generator.feature_dict
|
||||||
|
answer_dict = defaultdict(lambda: defaultdict(list))
|
||||||
|
new_answer_dict = {}
|
||||||
|
total_sp_dict = defaultdict(list)
|
||||||
|
new_total_sp_dict = defaultdict(list)
|
||||||
|
|
||||||
|
tokenizer = AlbertTokenizer.from_pretrained(albert_model_path)
|
||||||
|
new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]']
|
||||||
|
tokenizer.add_tokens(new_tokens)
|
||||||
|
|
||||||
|
reader = Reader(batch_size=batch_size,
|
||||||
|
encoder_ck_file=encoder_ck_file,
|
||||||
|
downstream_ck_file=downstream_ck_file)
|
||||||
|
|
||||||
|
print("start reading ...")
|
||||||
|
|
||||||
|
for _, batch in tqdm(enumerate(generator)):
|
||||||
|
input_ids = Tensor(batch["context_idxs"], mstype.int32)
|
||||||
|
attn_mask = Tensor(batch["context_mask"], mstype.int32)
|
||||||
|
token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
|
||||||
|
context_mask = Tensor(batch["context_mask"], mstype.float32)
|
||||||
|
square_mask = Tensor(batch["square_mask"], mstype.float32)
|
||||||
|
packing_mask = Tensor(batch["query_mapping"], mstype.float32)
|
||||||
|
para_start_mapping = Tensor(batch["para_start_mapping"], mstype.float32)
|
||||||
|
sent_end_mapping = Tensor(batch["sent_end_mapping"], mstype.float32)
|
||||||
|
unique_ids = batch["unique_ids"]
|
||||||
|
sent_names = batch["sent_names"]
|
||||||
|
cache_mask = Tensor(np.tril(np.triu(np.ones((seq_len, seq_len)), 0), 30), mstype.float32)
|
||||||
|
|
||||||
|
_, _, q_type, _, sent_logit, y1, y2 = reader(input_ids, attn_mask, token_type_ids,
|
||||||
|
context_mask, square_mask, packing_mask, cache_mask,
|
||||||
|
para_start_mapping, sent_end_mapping)
|
||||||
|
|
||||||
|
type_prob = ops.Softmax()(q_type).asnumpy()
|
||||||
|
|
||||||
|
answer_dict_ = convert_to_tokens(example_dict,
|
||||||
|
feature_dict,
|
||||||
|
batch['ids'],
|
||||||
|
y1.asnumpy().tolist(),
|
||||||
|
y2.asnumpy().tolist(),
|
||||||
|
type_prob,
|
||||||
|
tokenizer,
|
||||||
|
sent_logit.asnumpy(),
|
||||||
|
sent_names,
|
||||||
|
unique_ids)
|
||||||
|
for q_id in answer_dict_:
|
||||||
|
answer_dict[q_id] = answer_dict_[q_id]
|
||||||
|
|
||||||
|
for q_id in answer_dict:
|
||||||
|
res = answer_dict[q_id]
|
||||||
|
answer_text_ = res[0]
|
||||||
|
sent_ = res[1]
|
||||||
|
sent_names_ = res[2]
|
||||||
|
new_answer_dict[q_id] = answer_text_
|
||||||
|
|
||||||
|
predict_support_np = ops.Sigmoid()(Tensor(sent_, mstype.float32)).asnumpy()
|
||||||
|
|
||||||
|
for j in range(predict_support_np.shape[0]):
|
||||||
|
if j >= len(sent_names_):
|
||||||
|
break
|
||||||
|
if predict_support_np[j] > sp_threshold:
|
||||||
|
total_sp_dict[q_id].append(sent_names_[j])
|
||||||
|
|
||||||
|
for _id in total_sp_dict:
|
||||||
|
_sent_names = total_sp_dict[_id]
|
||||||
|
for para in _sent_names:
|
||||||
|
title = make_wiki_id(para[0], 0)
|
||||||
|
para_original_title = doc_db.get_doc_info(title)[-1]
|
||||||
|
para[0] = para_original_title
|
||||||
|
new_total_sp_dict[_id].append(para)
|
||||||
|
|
||||||
|
prediction = {'answer': new_answer_dict,
|
||||||
|
'sp': new_total_sp_dict}
|
||||||
|
|
||||||
|
with open(reader_result_file, 'w') as f:
|
||||||
|
json.dump(prediction, f, indent=4)
|
||||||
|
|
||||||
|
t2 = time()
|
||||||
|
|
||||||
|
print(f"reader cost time: {t2-t1} s")
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,183 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""define a data generator"""
|
||||||
|
|
||||||
|
import gzip
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
class DataGenerator:
|
||||||
|
"""data generator for reranker and reader"""
|
||||||
|
def __init__(self, feature_file_path, example_file_path, batch_size, seq_len,
|
||||||
|
para_limit=None, sent_limit=None, task_type=None):
|
||||||
|
"""init function"""
|
||||||
|
self.example_ptr = 0
|
||||||
|
self.bsz = batch_size
|
||||||
|
self.seq_length = seq_len
|
||||||
|
self.para_limit = para_limit
|
||||||
|
self.sent_limit = sent_limit
|
||||||
|
self.task_type = task_type
|
||||||
|
|
||||||
|
self.feature_file_path = feature_file_path
|
||||||
|
self.example_file_path = example_file_path
|
||||||
|
self.features = self.load_features()
|
||||||
|
self.examples = self.load_examples()
|
||||||
|
self.feature_dict = self.get_feature_dict()
|
||||||
|
self.example_dict = self.get_example_dict()
|
||||||
|
|
||||||
|
self.features = self.padding_feature(self.features, self.bsz)
|
||||||
|
|
||||||
|
def load_features(self):
|
||||||
|
"""load features from feature file"""
|
||||||
|
with gzip.open(self.feature_file_path, 'rb') as fin:
|
||||||
|
features = pickle.load(fin)
|
||||||
|
print("load features successful !!!")
|
||||||
|
return features
|
||||||
|
|
||||||
|
def padding_feature(self, features, bsz):
|
||||||
|
"""padding features as multiples of batch size"""
|
||||||
|
padding_num = ((len(features) // bsz + 1) * bsz - len(features))
|
||||||
|
print(f"features padding num is {padding_num}")
|
||||||
|
new_features = features + features[:padding_num]
|
||||||
|
return new_features
|
||||||
|
|
||||||
|
def load_examples(self):
|
||||||
|
"""laod examples from file"""
|
||||||
|
if self.example_file_path:
|
||||||
|
with gzip.open(self.example_file_path, 'rb') as fin:
|
||||||
|
examples = pickle.load(fin)
|
||||||
|
print("load examples successful !!!")
|
||||||
|
return examples
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_feature_dict(self):
|
||||||
|
"""build a feature dict"""
|
||||||
|
return {f.unique_id: f for f in self.features}
|
||||||
|
|
||||||
|
def get_example_dict(self):
|
||||||
|
"""build a example dict"""
|
||||||
|
if self.example_file_path:
|
||||||
|
return {e.unique_id: e for e in self.examples}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def common_process_single_case(self, i, case, context_idxs, context_mask, segment_idxs, ids, path, unique_ids):
|
||||||
|
"""common process for a single case"""
|
||||||
|
context_idxs[i] = np.array(case.doc_input_ids)
|
||||||
|
context_mask[i] = np.array(case.doc_input_mask)
|
||||||
|
segment_idxs[i] = np.array(case.doc_segment_ids)
|
||||||
|
|
||||||
|
ids.append(case.qas_id)
|
||||||
|
path.append(case.path)
|
||||||
|
unique_ids.append(case.unique_id)
|
||||||
|
|
||||||
|
return context_idxs, context_mask, segment_idxs, ids, path, unique_ids
|
||||||
|
|
||||||
|
def reader_process_single_case(self, i, case, sent_names, square_mask, query_mapping, ques_start_mapping,
|
||||||
|
para_start_mapping, sent_end_mapping):
|
||||||
|
"""process for a single case about reader"""
|
||||||
|
sent_names.append(case.sent_names)
|
||||||
|
prev_position = None
|
||||||
|
for cur_position, token_id in enumerate(case.doc_input_ids):
|
||||||
|
if token_id >= 30000:
|
||||||
|
if prev_position:
|
||||||
|
square_mask[i, prev_position + 1: cur_position, prev_position + 1: cur_position] = 1.0
|
||||||
|
prev_position = cur_position
|
||||||
|
if case.sent_spans:
|
||||||
|
for j in range(case.sent_spans[0][0] - 1):
|
||||||
|
query_mapping[i, j] = 1
|
||||||
|
ques_start_mapping[i, 0, 1] = 1
|
||||||
|
for j, para_span in enumerate(case.para_spans[:self.para_limit]):
|
||||||
|
start, end, _ = para_span
|
||||||
|
if start <= end:
|
||||||
|
para_start_mapping[i, j, start] = 1
|
||||||
|
for j, sent_span in enumerate(case.sent_spans[:self.sent_limit]):
|
||||||
|
start, end = sent_span
|
||||||
|
if start <= end:
|
||||||
|
end = min(end, self.seq_length - 1)
|
||||||
|
sent_end_mapping[i, j, end] = 1
|
||||||
|
return sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""iteration function"""
|
||||||
|
while True:
|
||||||
|
if self.example_ptr >= len(self.features):
|
||||||
|
break
|
||||||
|
start_id = self.example_ptr
|
||||||
|
cur_bsz = min(self.bsz, len(self.features) - start_id)
|
||||||
|
cur_batch = self.features[start_id: start_id + cur_bsz]
|
||||||
|
# BERT input
|
||||||
|
context_idxs = np.zeros((cur_bsz, self.seq_length))
|
||||||
|
context_mask = np.zeros((cur_bsz, self.seq_length))
|
||||||
|
segment_idxs = np.zeros((cur_bsz, self.seq_length))
|
||||||
|
|
||||||
|
# others
|
||||||
|
ids = []
|
||||||
|
path = []
|
||||||
|
unique_ids = []
|
||||||
|
|
||||||
|
if self.task_type == "reader":
|
||||||
|
# Mappings
|
||||||
|
ques_start_mapping = np.zeros((cur_bsz, 1, self.seq_length))
|
||||||
|
query_mapping = np.zeros((cur_bsz, self.seq_length))
|
||||||
|
para_start_mapping = np.zeros((cur_bsz, self.para_limit, self.seq_length))
|
||||||
|
sent_end_mapping = np.zeros((cur_bsz, self.sent_limit, self.seq_length))
|
||||||
|
square_mask = np.zeros((cur_bsz, self.seq_length, self.seq_length))
|
||||||
|
sent_names = []
|
||||||
|
|
||||||
|
for i, case in enumerate(cur_batch):
|
||||||
|
context_idxs, context_mask, segment_idxs, ids, path, unique_ids = \
|
||||||
|
self.common_process_single_case(i, case, context_idxs, context_mask, segment_idxs, ids, path,
|
||||||
|
unique_ids)
|
||||||
|
if self.task_type == "reader":
|
||||||
|
sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping = \
|
||||||
|
self.reader_process_single_case(i, case, sent_names, square_mask, query_mapping,
|
||||||
|
ques_start_mapping, para_start_mapping, sent_end_mapping)
|
||||||
|
|
||||||
|
self.example_ptr += cur_bsz
|
||||||
|
|
||||||
|
if self.task_type == "reranker":
|
||||||
|
yield {
|
||||||
|
"context_idxs": context_idxs,
|
||||||
|
"context_mask": context_mask,
|
||||||
|
"segment_idxs": segment_idxs,
|
||||||
|
|
||||||
|
"ids": ids,
|
||||||
|
"unique_ids": unique_ids,
|
||||||
|
"path": path
|
||||||
|
}
|
||||||
|
elif self.task_type == "reader":
|
||||||
|
yield {
|
||||||
|
"context_idxs": context_idxs,
|
||||||
|
"context_mask": context_mask,
|
||||||
|
"segment_idxs": segment_idxs,
|
||||||
|
"query_mapping": query_mapping,
|
||||||
|
"para_start_mapping": para_start_mapping,
|
||||||
|
"sent_end_mapping": sent_end_mapping,
|
||||||
|
"square_mask": square_mask,
|
||||||
|
"ques_start_mapping": ques_start_mapping,
|
||||||
|
|
||||||
|
"ids": ids,
|
||||||
|
"unique_ids": unique_ids,
|
||||||
|
"sent_names": sent_names,
|
||||||
|
"path": path
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print(f"data generator received a error type: {self.task_type} !!!")
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""downstream Model for reranker"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class Rerank_Downstream(nn.Cell):
|
||||||
|
"""Downstream model for rerank"""
|
||||||
|
def __init__(self):
|
||||||
|
"""init function"""
|
||||||
|
super(Rerank_Downstream, self).__init__()
|
||||||
|
self.dense_0 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True)
|
||||||
|
self.relu_1 = nn.ReLU()
|
||||||
|
self.reducemean_2 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.sub_3 = P.Sub()
|
||||||
|
self.sub_4 = P.Sub()
|
||||||
|
self.pow_5 = P.Pow()
|
||||||
|
self.pow_5_input_weight = 2.0
|
||||||
|
self.reducemean_6 = P.ReduceMean(keep_dims=True)
|
||||||
|
self.add_7 = P.Add()
|
||||||
|
self.add_7_bias = 9.999999960041972e-13
|
||||||
|
self.sqrt_8 = P.Sqrt()
|
||||||
|
self.div_9 = P.Div()
|
||||||
|
self.mul_10 = P.Mul()
|
||||||
|
self.mul_10_w = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
|
||||||
|
self.add_11 = P.Add()
|
||||||
|
self.add_11_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
|
||||||
|
self.dense_12 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct function"""
|
||||||
|
opt_dense_0 = self.dense_0(x)
|
||||||
|
opt_relu_1 = self.relu_1(opt_dense_0)
|
||||||
|
opt_reducemean_2 = self.reducemean_2(opt_relu_1, -1)
|
||||||
|
opt_sub_3 = self.sub_3(opt_relu_1, opt_reducemean_2)
|
||||||
|
opt_sub_4 = self.sub_4(opt_relu_1, opt_reducemean_2)
|
||||||
|
opt_pow_5 = self.pow_5(opt_sub_3, self.pow_5_input_weight)
|
||||||
|
opt_reducemean_6 = self.reducemean_6(opt_pow_5, -1)
|
||||||
|
opt_add_7 = self.add_7(opt_reducemean_6, self.add_7_bias)
|
||||||
|
opt_sqrt_8 = self.sqrt_8(opt_add_7)
|
||||||
|
opt_div_9 = self.div_9(opt_sub_4, opt_sqrt_8)
|
||||||
|
opt_mul_10 = self.mul_10(self.mul_10_w, opt_div_9)
|
||||||
|
opt_add_11 = self.add_11(opt_mul_10, self.add_11_bias)
|
||||||
|
opt_dense_12 = self.dense_12(opt_add_11)
|
||||||
|
return opt_dense_12
|
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Reranker Model"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import load_checkpoint, load_param_into_net
|
||||||
|
from src.rerank_albert_xxlarge import Rerank_Albert
|
||||||
|
from src.rerank_downstream import Rerank_Downstream
|
||||||
|
|
||||||
|
|
||||||
|
class Reranker(nn.Cell):
|
||||||
|
"""Reranker model"""
|
||||||
|
def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
|
||||||
|
"""init function"""
|
||||||
|
super(Reranker, self).__init__(auto_prefix=False)
|
||||||
|
|
||||||
|
self.encoder = Rerank_Albert(batch_size)
|
||||||
|
param_dict = load_checkpoint(encoder_ck_file)
|
||||||
|
not_load_params_1 = load_param_into_net(self.encoder, param_dict)
|
||||||
|
print(f"not loaded albert: {not_load_params_1}")
|
||||||
|
|
||||||
|
self.no_answer_mlp = Rerank_Downstream()
|
||||||
|
param_dict = load_checkpoint(downstream_ck_file)
|
||||||
|
not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict)
|
||||||
|
print(f"not loaded downstream: {not_load_params_2}")
|
||||||
|
|
||||||
|
def construct(self, input_ids, attn_mask, token_type_ids):
|
||||||
|
"""construct function"""
|
||||||
|
state = self.encoder(input_ids, attn_mask, token_type_ids)
|
||||||
|
state = state[:, 0, :]
|
||||||
|
|
||||||
|
no_answer = self.no_answer_mlp(state)
|
||||||
|
return no_answer
|
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""execute reranker"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from time import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import Tensor, ops
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
|
||||||
|
from src.rerank_and_reader_data_generator import DataGenerator
|
||||||
|
from src.reranker import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
def rerank(args):
|
||||||
|
"""rerank function"""
|
||||||
|
rerank_feature_file = args.rerank_feature_file
|
||||||
|
rerank_result_file = args.rerank_result_file
|
||||||
|
encoder_ck_file = args.rerank_encoder_ck_file
|
||||||
|
downstream_ck_file = args.rerank_downstream_ck_file
|
||||||
|
seed = args.seed
|
||||||
|
seq_len = args.seq_len
|
||||||
|
batch_size = args.rerank_batch_size
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
t1 = time()
|
||||||
|
|
||||||
|
generator = DataGenerator(feature_file_path=rerank_feature_file,
|
||||||
|
example_file_path=None,
|
||||||
|
batch_size=batch_size, seq_len=seq_len,
|
||||||
|
task_type="reranker")
|
||||||
|
gather_dict = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
|
reranker = Reranker(batch_size=batch_size,
|
||||||
|
encoder_ck_file=encoder_ck_file,
|
||||||
|
downstream_ck_file=downstream_ck_file)
|
||||||
|
|
||||||
|
print("start re-ranking ...")
|
||||||
|
|
||||||
|
for _, batch in tqdm(enumerate(generator)):
|
||||||
|
input_ids = Tensor(batch["context_idxs"], mstype.int32)
|
||||||
|
attn_mask = Tensor(batch["context_mask"], mstype.int32)
|
||||||
|
token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
|
||||||
|
|
||||||
|
no_answer = reranker(input_ids, attn_mask, token_type_ids)
|
||||||
|
|
||||||
|
no_answer_prob = ops.Softmax()(no_answer).asnumpy()
|
||||||
|
no_answer_prob = no_answer_prob[:, 0]
|
||||||
|
|
||||||
|
for i in range(len(batch['ids'])):
|
||||||
|
qas_id = batch['ids'][i]
|
||||||
|
gather_dict[qas_id][no_answer_prob[i]].append(batch['unique_ids'][i])
|
||||||
|
gather_dict[qas_id][no_answer_prob[i]].append(batch['path'][i])
|
||||||
|
|
||||||
|
rerank_result = {}
|
||||||
|
for qas_id in tqdm(gather_dict, desc="get top1 path from re-rank result"):
|
||||||
|
all_paths = gather_dict[qas_id]
|
||||||
|
all_paths = sorted(all_paths.items(), key=lambda item: item[0])
|
||||||
|
assert qas_id not in rerank_result
|
||||||
|
rerank_result[qas_id] = all_paths[0][1]
|
||||||
|
|
||||||
|
with open(rerank_result_file, 'w') as f:
|
||||||
|
json.dump(rerank_result, f)
|
||||||
|
|
||||||
|
t2 = time()
|
||||||
|
|
||||||
|
print(f"re-rank cost time: {t2-t1} s")
|
Loading…
Reference in new issue