!13375 add reranker and reader into modelzoo/research/tprr

From: @huenrui
Reviewed-by: @oacjiewen,@guoqi1024
Signed-off-by:
pull/13375/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5313abfb66

@ -38,6 +38,9 @@ Wikipedia data: the 2017 English Wikipedia dump version with bidirectional hyper
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
dev tf-idf data: the candidates for each question in dev data which is originated from top-500 retrieved from 5M paragraphs of Wikipedia
through TF-IDF.
The dataset of re-ranker consists of two parts:
Wikipedia data: the 2017 English Wikipedia dump version.
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
# [Features](#contents)
@ -64,6 +67,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
```python
# run evaluation example with HotPotQA dev dataset
sh run_eval_ascend.sh
sh run_eval_ascend_reranker_reader.sh
```
# [Script Description](#contents)
@ -75,25 +79,39 @@ After installing MindSpore via the official website and Dataset is correctly gen
└─tprr
├─README.md
├─scripts
| ├─run_eval_ascend.sh # Launch evaluation in ascend
| ├─run_eval_ascend.sh # Launch retriever evaluation in ascend
| └─run_eval_ascend_reranker_reader # Launch re-ranker and reader evaluation in ascend
|
├─src
| ├─config.py # Evaluation configurations
| ├─onehop.py # Onehop model
| ├─onehop_bert.py # Onehop bert model
| ├─process_data.py # Data preprocessing
| ├─twohop.py # Twohop model
| ├─twohop_bert.py # Twohop bert model
| └─utils.py # Utils for evaluation
| ├─build_reranker_data.py # build data for re-ranker from result of retriever
| ├─config.py # Evaluation configurations for retriever
| ├─hotpot_evaluate_v1.py # Hotpotqa evaluation script
| ├─onehop.py # Onehop model of retriever
| ├─onehop_bert.py # Onehop bert model of retriever
| ├─process_data.py # Data preprocessing for retriever
| ├─reader.py # Reader model
| ├─reader_albert_xxlarge.py # Albert-xxlarge module of reader model
| ├─reader_downstream.py # Downstream module of reader model
| ├─reader_eval.py # Reader evaluation script
| ├─rerank_albert_xxlarge.py # Albert-xxlarge module of re-ranker model
| ├─rerank_and_reader_data_generator.py # Data generator for re-ranker and reader
| ├─rerank_and_reader_utils.py # Utils for re-ranker and reader
| ├─rerank_downstream.py # Downstream module of re-ranker model
| ├─reranker.py # Re-ranker model
| ├─reranker_eval.py # Re-ranker evaluation script
| ├─twohop.py # Twohop model of retriever
| ├─twohop_bert.py # Twohop bert model of retriever
| └─utils.py # Utils for retriever
|
└─retriever_eval.py # Evaluation net for retriever
├─retriever_eval.py # Evaluation net for retriever
└─reranker_and_reader_eval.py # Evaluation net for re-ranker and reader
```
## [Script Parameters](#contents)
Parameters for evaluation can be set in config.py.
Parameters for retriever evaluation can be set in config.py.
- config for TPRR retriever dataset
- config for TPRR retriever
```python
"q_len": 64, # Max query length
@ -108,17 +126,30 @@ Parameters for evaluation can be set in config.py.
config.py for more configuration.
Parameters for re-ranker and reader evaluation can be passed directly at execution time.
- parameters for TPRR re-ranker and reader
```python
"seq_len": 512, # sequence length
"rerank_batch_size": 32, # batch size for re-ranker evaluation
"reader_batch_size": 448, # batch size for reader evaluation
"sp_threshold": 8 # threshold for picking supporting sentence
```
config.py for more configuration.
## [Evaluation Process](#contents)
### Evaluation
- Evaluation on Ascend
- Retriever evaluation on Ascend
```python
sh run_eval_ascend.sh
```
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
Evaluation result will be stored in the scripts path, whose folder name begins with "eval_tr". You can find the result like the
followings in log.
```python
@ -138,6 +169,35 @@ Parameters for evaluation can be set in config.py.
evaluation time (h): 20.155506462653477
```
- Re-ranker and reader evaluation on Ascend
Use the output of retriever as input of re-ranker
```python
sh run_eval_ascend_reranker_reader.sh
```
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
followings in log.
```python
total top1 pem: 0.8803511141120864
...
em: 0.67440918298447
f1: 0.8025625656569652
prec: 0.8292800393689271
recall: 0.8136908451841731
sp_em: 0.6009453072248481
sp_f1: 0.844555664157302
sp_prec: 0.8640844345841021
sp_recall: 0.8446123918845106
joint_em: 0.4537474679270763
joint_f1: 0.715119580346802
joint_prec: 0.7540052057184267
joint_recall: 0.7250240424067661
```
# [Model Description](#contents)
## [Performance](#contents)
@ -154,6 +214,8 @@ Parameters for evaluation can be set in config.py.
| Batch_size | 1 |
| Output | inference path |
| PEM | 0.9188 |
| total top1 pem | 0.88 |
| joint_f1 | 0.7151 |
# [Description of random situation](#contents)

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main file"""
from mindspore import context
from src.rerank_and_reader_utils import get_parse, cal_reranker_metrics, select_reader_dev_data
from src.reranker_eval import rerank
from src.reader_eval import read
from src.hotpot_evaluate_v1 import hotpotqa_eval
from src.build_reranker_data import get_rerank_data
def rerank_and_retriever_eval():
"""main function"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
parser = get_parse()
args = parser.parse_args()
if args.get_reranker_data:
get_rerank_data(args)
if args.run_reranker:
rerank(args)
if args.cal_reranker_metrics:
total_top1_pem, _, _ = \
cal_reranker_metrics(dev_gold_file=args.dev_gold_file, rerank_result_file=args.rerank_result_file)
print(f"total top1 pem: {total_top1_pem}")
if args.select_reader_data:
select_reader_dev_data(args)
if args.run_reader:
read(args)
if args.cal_reader_metrics:
metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_file)
for k in metrics:
print(f"{k}: {metrics[k]}")
if __name__ == "__main__":
rerank_and_retriever_eval()

@ -21,16 +21,16 @@ export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
if [ -d "eval_tr" ];
then
rm -rf ./eval
rm -rf ./eval_tr
fi
mkdir ./eval
mkdir ./eval_tr
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
cp ../*.py ./eval_tr
cp *.sh ./eval_tr
cp -r ../src ./eval_tr
cd ./eval_tr || exit
env > env.log
echo "start evaluation"

@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# eval script
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation"
python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics > log_reranker_and_reader.txt 2>&1 &
cd ..

File diff suppressed because it is too large Load Diff

@ -33,14 +33,14 @@ def ThinkRetrieverConfig():
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
parser.add_argument("--save_path", type=str, default='./', help='path of output')
parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path")
parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path")
parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json',
parser.add_argument("--save_path", type=str, default='../', help='path of output')
parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
parser.add_argument("--wiki_path", type=str, default='../db_docs_bidirection_new.pkl', help="wiki path")
parser.add_argument("--dev_path", type=str, default='../hotpot_dev_fullwiki_v1_for_retriever.json',
help="dev path")
parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path")
parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path")
parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path")
parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path")
parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path")
parser.add_argument("--dev_data_path", type=str, default='../dev_tf_idf_data_raw.pkl', help="dev data path")
parser.add_argument("--onehop_bert_path", type=str, default='../onehop.ckpt', help="onehop bert ckpt path")
parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
return parser.parse_args()

@ -0,0 +1,153 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hotpotqa evaluate script"""
import re
import string
from collections import Counter
import ujson as json
def normalize_answer(s):
"""normalize answer"""
def remove_articles(text):
"""remove articles"""
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
"""fix whitespace"""
return ' '.join(text.split())
def remove_punc(text):
"""remove punctuation from text"""
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
"""lower text"""
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
"""calculate f1 score"""
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)
ZERO_METRIC = (0, 0, 0)
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return ZERO_METRIC
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall
def exact_match_score(prediction, ground_truth):
"""calculate exact match score"""
return normalize_answer(prediction) == normalize_answer(ground_truth)
def update_answer(metrics, prediction, gold):
"""update answer"""
em = exact_match_score(prediction, gold)
f1, prec, recall = f1_score(prediction, gold)
metrics['em'] += float(em)
metrics['f1'] += f1
metrics['prec'] += prec
metrics['recall'] += recall
return em, prec, recall
def update_sp(metrics, prediction, gold):
"""update supporting sentences"""
cur_sp_pred = set(map(tuple, prediction))
gold_sp_pred = set(map(tuple, gold))
tp, fp, fn = 0, 0, 0
for e in cur_sp_pred:
if e in gold_sp_pred:
tp += 1
else:
fp += 1
for e in gold_sp_pred:
if e not in cur_sp_pred:
fn += 1
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
em = 1.0 if fp + fn == 0 else 0.0
metrics['sp_em'] += em
metrics['sp_f1'] += f1
metrics['sp_prec'] += prec
metrics['sp_recall'] += recall
return em, prec, recall
def hotpotqa_eval(prediction_file, gold_file):
"""hotpotqa evaluate function"""
with open(prediction_file) as f:
prediction = json.load(f)
with open(gold_file) as f:
gold = json.load(f)
metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
for dp in gold:
cur_id = dp['_id']
can_eval_joint = True
if cur_id not in prediction['answer']:
print('missing answer {}'.format(cur_id))
can_eval_joint = False
else:
em, prec, recall = update_answer(
metrics, prediction['answer'][cur_id], dp['answer'])
if cur_id not in prediction['sp']:
print('missing sp fact {}'.format(cur_id))
can_eval_joint = False
else:
sp_em, sp_prec, sp_recall = update_sp(
metrics, prediction['sp'][cur_id], dp['supporting_facts'])
if can_eval_joint:
joint_prec = prec * sp_prec
joint_recall = recall * sp_recall
if joint_prec + joint_recall > 0:
joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
else:
joint_f1 = 0.
joint_em = em * sp_em
metrics['joint_em'] += joint_em
metrics['joint_f1'] += joint_f1
metrics['joint_prec'] += joint_prec
metrics['joint_recall'] += joint_recall
num = len(gold)
for k in metrics:
metrics[k] /= num
return metrics

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reader model"""
import mindspore.nn as nn
from mindspore import load_checkpoint, load_param_into_net
from mindspore.ops import BatchMatMul
from mindspore import ops
from mindspore import dtype as mstype
from src.reader_albert_xxlarge import Reader_Albert
from src.reader_downstream import Reader_Downstream
dst_type = mstype.float16
dst_type2 = mstype.float32
class Reader(nn.Cell):
"""Reader model"""
def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
"""init function"""
super(Reader, self).__init__(auto_prefix=False)
self.encoder = Reader_Albert(batch_size)
param_dict = load_checkpoint(encoder_ck_file)
not_load_params = load_param_into_net(self.encoder, param_dict)
print(f"not loaded: {not_load_params}")
self.downstream = Reader_Downstream()
param_dict = load_checkpoint(downstream_ck_file)
not_load_params = load_param_into_net(self.downstream, param_dict)
print(f"not loaded: {not_load_params}")
self.bmm = BatchMatMul()
def construct(self, input_ids, attn_mask, token_type_ids,
context_mask, square_mask, packing_mask, cache_mask,
para_start_mapping, sent_end_mapping):
"""construct function"""
state = self.encoder(attn_mask, input_ids, token_type_ids)
para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, 2, D]
sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, max_sent, D]
q_type, start, end, para_logit, sent_logit = self.downstream(ops.Cast()(para_state, dst_type2),
ops.Cast()(sent_state, dst_type2),
state,
context_mask)
outer = start[:, :, None] + end[:, None]
outer_mask = cache_mask
outer_mask = square_mask * outer_mask[None]
outer = outer - 1e30 * (1 - outer_mask)
outer = outer - 1e30 * packing_mask[:, :, None]
max_row = ops.ReduceMax()(outer, 2)
y1 = ops.Argmax()(max_row)
max_col = ops.ReduceMax()(outer, 1)
y2 = ops.Argmax()(max_col)
return start, end, q_type, para_logit, sent_logit, y1, y2

File diff suppressed because it is too large Load Diff

@ -0,0 +1,213 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""downstream Model for reader"""
import numpy as np
from mindspore import nn, ops
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore import dtype as mstype
dst_type = mstype.float16
dst_type2 = mstype.float32
class Module15(nn.Cell):
"""module of reader downstream"""
def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
"""init function"""
super(Module15, self).__init__()
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
name=None)
self.add_1 = P.Add()
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
self.relu_2 = nn.ReLU()
def construct(self, x):
"""construct function"""
opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
opt_relu_2 = self.relu_2(opt_add_1)
return opt_relu_2
class NormModule(nn.Cell):
"""Normalization module of reader downstream"""
def __init__(self, mul_8_w_shape, add_9_bias_shape):
"""init function"""
super(NormModule, self).__init__()
self.reducemean_0 = P.ReduceMean(keep_dims=True)
self.sub_1 = P.Sub()
self.sub_2 = P.Sub()
self.pow_3 = P.Pow()
self.pow_3_input_weight = 2.0
self.reducemean_4 = P.ReduceMean(keep_dims=True)
self.add_5 = P.Add()
self.add_5_bias = 9.999999960041972e-13
self.sqrt_6 = P.Sqrt()
self.div_7 = P.Div()
self.mul_8 = P.Mul()
self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, mul_8_w_shape).astype(np.float32)), name=None)
self.add_9 = P.Add()
self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, add_9_bias_shape).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
opt_reducemean_0 = self.reducemean_0(x, -1)
opt_sub_1 = self.sub_1(x, opt_reducemean_0)
opt_sub_2 = self.sub_2(x, opt_reducemean_0)
opt_pow_3 = self.pow_3(opt_sub_1, self.pow_3_input_weight)
opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1)
opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias)
opt_sqrt_6 = self.sqrt_6(opt_add_5)
opt_div_7 = self.div_7(opt_sub_2, opt_sqrt_6)
opt_mul_8 = self.mul_8(self.mul_8_w, opt_div_7)
opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias)
return opt_add_9
class Module16(nn.Cell):
"""module of reader downstream"""
def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
normmodule_0_add_9_bias_shape):
"""init function"""
super(Module16, self).__init__()
self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
add_1_bias_shape=module15_0_add_1_bias_shape)
self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
add_9_bias_shape=normmodule_0_add_9_bias_shape)
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
module15_0_opt = self.module15_0(x)
normmodule_0_opt = self.normmodule_0(module15_0_opt)
opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
return ops.Cast()(opt_matmul_0, dst_type2)
class Module17(nn.Cell):
"""module of reader downstream"""
def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
normmodule_0_add_9_bias_shape):
"""init function"""
super(Module17, self).__init__()
self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
add_1_bias_shape=module15_0_add_1_bias_shape)
self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
add_9_bias_shape=normmodule_0_add_9_bias_shape)
self.matmul_0 = nn.MatMul()
self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None)
self.add_1 = P.Add()
self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
def construct(self, x):
"""construct function"""
module15_0_opt = self.module15_0(x)
normmodule_0_opt = self.normmodule_0(module15_0_opt)
opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
return opt_add_1
class Module5(nn.Cell):
"""module of reader downstream"""
def __init__(self):
"""init function"""
super(Module5, self).__init__()
self.sub_0 = P.Sub()
self.sub_0_bias = 1.0
self.mul_1 = P.Mul()
self.mul_1_w = 1.0000000150474662e+30
def construct(self, x):
"""construct function"""
opt_sub_0 = self.sub_0(self.sub_0_bias, x)
opt_mul_1 = self.mul_1(opt_sub_0, self.mul_1_w)
return opt_mul_1
class Module10(nn.Cell):
"""module of reader downstream"""
def __init__(self):
"""init function"""
super(Module10, self).__init__()
self.squeeze_0 = P.Squeeze(2)
self.module5_0 = Module5()
self.sub_1 = P.Sub()
def construct(self, x, x0):
"""construct function"""
opt_squeeze_0 = self.squeeze_0(x)
module5_0_opt = self.module5_0(x0)
opt_sub_1 = self.sub_1(opt_squeeze_0, module5_0_opt)
return opt_sub_1
class Reader_Downstream(nn.Cell):
"""Downstream model for reader"""
def __init__(self):
"""init function"""
super(Reader_Downstream, self).__init__()
self.module16_0 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
module15_0_add_1_bias_shape=(8192,),
normmodule_0_mul_8_w_shape=(8192,),
normmodule_0_add_9_bias_shape=(8192,))
self.add_74 = P.Add()
self.add_74_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
self.module16_1 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
module15_0_add_1_bias_shape=(8192,),
normmodule_0_mul_8_w_shape=(8192,),
normmodule_0_add_9_bias_shape=(8192,))
self.add_75 = P.Add()
self.add_75_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
self.module17_0 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
module15_0_add_1_bias_shape=(4096,),
normmodule_0_mul_8_w_shape=(4096,),
normmodule_0_add_9_bias_shape=(4096,))
self.module10_0 = Module10()
self.module17_1 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
module15_0_add_1_bias_shape=(4096,),
normmodule_0_mul_8_w_shape=(4096,),
normmodule_0_add_9_bias_shape=(4096,))
self.module10_1 = Module10()
self.gather_6_input_weight = Tensor(np.array(0))
self.gather_6_axis = 1
self.gather_6 = P.Gather()
self.dense_13 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True)
self.relu_18 = nn.ReLU()
self.normmodule_0 = NormModule(mul_8_w_shape=(4096,), add_9_bias_shape=(4096,))
self.dense_73 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True)
def construct(self, x, x0, x1, x2):
"""construct function"""
module16_0_opt = self.module16_0(x)
opt_add_74 = self.add_74(module16_0_opt, self.add_74_bias)
module16_1_opt = self.module16_1(x0)
opt_add_75 = self.add_75(module16_1_opt, self.add_75_bias)
module17_0_opt = self.module17_0(x1)
opt_module10_0 = self.module10_0(module17_0_opt, x2)
module17_1_opt = self.module17_1(x1)
opt_module10_1 = self.module10_1(module17_1_opt, x2)
opt_gather_6_axis = self.gather_6_axis
opt_gather_6 = self.gather_6(x1, self.gather_6_input_weight, opt_gather_6_axis)
opt_dense_13 = self.dense_13(opt_gather_6)
opt_relu_18 = self.relu_18(opt_dense_13)
normmodule_0_opt = self.normmodule_0(opt_relu_18)
opt_dense_73 = self.dense_73(normmodule_0_opt)
return opt_dense_73, opt_module10_0, opt_module10_1, opt_add_74, opt_add_75

@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""execute reader"""
from collections import defaultdict
import random
from time import time
import json
from tqdm import tqdm
import numpy as np
from transformers import AlbertTokenizer
from mindspore import Tensor, ops
from mindspore import dtype as mstype
from src.rerank_and_reader_data_generator import DataGenerator
from src.rerank_and_reader_utils import convert_to_tokens, make_wiki_id, DocDB
from src.reader import Reader
def read(args):
"""reader function"""
db_file = args.wiki_db_file
reader_feature_file = args.reader_feature_file
reader_example_file = args.reader_example_file
encoder_ck_file = args.reader_encoder_ck_file
downstream_ck_file = args.reader_downstream_ck_file
albert_model_path = args.albert_model_path
reader_result_file = args.reader_result_file
seed = args.seed
sp_threshold = args.sp_threshold
seq_len = args.seq_len
batch_size = args.reader_batch_size
para_limit = args.max_para_num
sent_limit = args.max_sent_num
random.seed(seed)
np.random.seed(seed)
t1 = time()
doc_db = DocDB(db_file)
generator = DataGenerator(feature_file_path=reader_feature_file,
example_file_path=reader_example_file,
batch_size=batch_size, seq_len=seq_len,
para_limit=para_limit, sent_limit=sent_limit,
task_type="reader")
example_dict = generator.example_dict
feature_dict = generator.feature_dict
answer_dict = defaultdict(lambda: defaultdict(list))
new_answer_dict = {}
total_sp_dict = defaultdict(list)
new_total_sp_dict = defaultdict(list)
tokenizer = AlbertTokenizer.from_pretrained(albert_model_path)
new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]']
tokenizer.add_tokens(new_tokens)
reader = Reader(batch_size=batch_size,
encoder_ck_file=encoder_ck_file,
downstream_ck_file=downstream_ck_file)
print("start reading ...")
for _, batch in tqdm(enumerate(generator)):
input_ids = Tensor(batch["context_idxs"], mstype.int32)
attn_mask = Tensor(batch["context_mask"], mstype.int32)
token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
context_mask = Tensor(batch["context_mask"], mstype.float32)
square_mask = Tensor(batch["square_mask"], mstype.float32)
packing_mask = Tensor(batch["query_mapping"], mstype.float32)
para_start_mapping = Tensor(batch["para_start_mapping"], mstype.float32)
sent_end_mapping = Tensor(batch["sent_end_mapping"], mstype.float32)
unique_ids = batch["unique_ids"]
sent_names = batch["sent_names"]
cache_mask = Tensor(np.tril(np.triu(np.ones((seq_len, seq_len)), 0), 30), mstype.float32)
_, _, q_type, _, sent_logit, y1, y2 = reader(input_ids, attn_mask, token_type_ids,
context_mask, square_mask, packing_mask, cache_mask,
para_start_mapping, sent_end_mapping)
type_prob = ops.Softmax()(q_type).asnumpy()
answer_dict_ = convert_to_tokens(example_dict,
feature_dict,
batch['ids'],
y1.asnumpy().tolist(),
y2.asnumpy().tolist(),
type_prob,
tokenizer,
sent_logit.asnumpy(),
sent_names,
unique_ids)
for q_id in answer_dict_:
answer_dict[q_id] = answer_dict_[q_id]
for q_id in answer_dict:
res = answer_dict[q_id]
answer_text_ = res[0]
sent_ = res[1]
sent_names_ = res[2]
new_answer_dict[q_id] = answer_text_
predict_support_np = ops.Sigmoid()(Tensor(sent_, mstype.float32)).asnumpy()
for j in range(predict_support_np.shape[0]):
if j >= len(sent_names_):
break
if predict_support_np[j] > sp_threshold:
total_sp_dict[q_id].append(sent_names_[j])
for _id in total_sp_dict:
_sent_names = total_sp_dict[_id]
for para in _sent_names:
title = make_wiki_id(para[0], 0)
para_original_title = doc_db.get_doc_info(title)[-1]
para[0] = para_original_title
new_total_sp_dict[_id].append(para)
prediction = {'answer': new_answer_dict,
'sp': new_total_sp_dict}
with open(reader_result_file, 'w') as f:
json.dump(prediction, f, indent=4)
t2 = time()
print(f"reader cost time: {t2-t1} s")

File diff suppressed because it is too large Load Diff

@ -0,0 +1,183 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define a data generator"""
import gzip
import pickle
import random
import numpy as np
random.seed(42)
np.random.seed(42)
class DataGenerator:
"""data generator for reranker and reader"""
def __init__(self, feature_file_path, example_file_path, batch_size, seq_len,
para_limit=None, sent_limit=None, task_type=None):
"""init function"""
self.example_ptr = 0
self.bsz = batch_size
self.seq_length = seq_len
self.para_limit = para_limit
self.sent_limit = sent_limit
self.task_type = task_type
self.feature_file_path = feature_file_path
self.example_file_path = example_file_path
self.features = self.load_features()
self.examples = self.load_examples()
self.feature_dict = self.get_feature_dict()
self.example_dict = self.get_example_dict()
self.features = self.padding_feature(self.features, self.bsz)
def load_features(self):
"""load features from feature file"""
with gzip.open(self.feature_file_path, 'rb') as fin:
features = pickle.load(fin)
print("load features successful !!!")
return features
def padding_feature(self, features, bsz):
"""padding features as multiples of batch size"""
padding_num = ((len(features) // bsz + 1) * bsz - len(features))
print(f"features padding num is {padding_num}")
new_features = features + features[:padding_num]
return new_features
def load_examples(self):
"""laod examples from file"""
if self.example_file_path:
with gzip.open(self.example_file_path, 'rb') as fin:
examples = pickle.load(fin)
print("load examples successful !!!")
return examples
return {}
def get_feature_dict(self):
"""build a feature dict"""
return {f.unique_id: f for f in self.features}
def get_example_dict(self):
"""build a example dict"""
if self.example_file_path:
return {e.unique_id: e for e in self.examples}
return {}
def common_process_single_case(self, i, case, context_idxs, context_mask, segment_idxs, ids, path, unique_ids):
"""common process for a single case"""
context_idxs[i] = np.array(case.doc_input_ids)
context_mask[i] = np.array(case.doc_input_mask)
segment_idxs[i] = np.array(case.doc_segment_ids)
ids.append(case.qas_id)
path.append(case.path)
unique_ids.append(case.unique_id)
return context_idxs, context_mask, segment_idxs, ids, path, unique_ids
def reader_process_single_case(self, i, case, sent_names, square_mask, query_mapping, ques_start_mapping,
para_start_mapping, sent_end_mapping):
"""process for a single case about reader"""
sent_names.append(case.sent_names)
prev_position = None
for cur_position, token_id in enumerate(case.doc_input_ids):
if token_id >= 30000:
if prev_position:
square_mask[i, prev_position + 1: cur_position, prev_position + 1: cur_position] = 1.0
prev_position = cur_position
if case.sent_spans:
for j in range(case.sent_spans[0][0] - 1):
query_mapping[i, j] = 1
ques_start_mapping[i, 0, 1] = 1
for j, para_span in enumerate(case.para_spans[:self.para_limit]):
start, end, _ = para_span
if start <= end:
para_start_mapping[i, j, start] = 1
for j, sent_span in enumerate(case.sent_spans[:self.sent_limit]):
start, end = sent_span
if start <= end:
end = min(end, self.seq_length - 1)
sent_end_mapping[i, j, end] = 1
return sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping
def __iter__(self):
"""iteration function"""
while True:
if self.example_ptr >= len(self.features):
break
start_id = self.example_ptr
cur_bsz = min(self.bsz, len(self.features) - start_id)
cur_batch = self.features[start_id: start_id + cur_bsz]
# BERT input
context_idxs = np.zeros((cur_bsz, self.seq_length))
context_mask = np.zeros((cur_bsz, self.seq_length))
segment_idxs = np.zeros((cur_bsz, self.seq_length))
# others
ids = []
path = []
unique_ids = []
if self.task_type == "reader":
# Mappings
ques_start_mapping = np.zeros((cur_bsz, 1, self.seq_length))
query_mapping = np.zeros((cur_bsz, self.seq_length))
para_start_mapping = np.zeros((cur_bsz, self.para_limit, self.seq_length))
sent_end_mapping = np.zeros((cur_bsz, self.sent_limit, self.seq_length))
square_mask = np.zeros((cur_bsz, self.seq_length, self.seq_length))
sent_names = []
for i, case in enumerate(cur_batch):
context_idxs, context_mask, segment_idxs, ids, path, unique_ids = \
self.common_process_single_case(i, case, context_idxs, context_mask, segment_idxs, ids, path,
unique_ids)
if self.task_type == "reader":
sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping = \
self.reader_process_single_case(i, case, sent_names, square_mask, query_mapping,
ques_start_mapping, para_start_mapping, sent_end_mapping)
self.example_ptr += cur_bsz
if self.task_type == "reranker":
yield {
"context_idxs": context_idxs,
"context_mask": context_mask,
"segment_idxs": segment_idxs,
"ids": ids,
"unique_ids": unique_ids,
"path": path
}
elif self.task_type == "reader":
yield {
"context_idxs": context_idxs,
"context_mask": context_mask,
"segment_idxs": segment_idxs,
"query_mapping": query_mapping,
"para_start_mapping": para_start_mapping,
"sent_end_mapping": sent_end_mapping,
"square_mask": square_mask,
"ques_start_mapping": ques_start_mapping,
"ids": ids,
"unique_ids": unique_ids,
"sent_names": sent_names,
"path": path
}
else:
print(f"data generator received a error type: {self.task_type} !!!")

File diff suppressed because it is too large Load Diff

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""downstream Model for reranker"""
import numpy as np
from mindspore import nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
class Rerank_Downstream(nn.Cell):
"""Downstream model for rerank"""
def __init__(self):
"""init function"""
super(Rerank_Downstream, self).__init__()
self.dense_0 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True)
self.relu_1 = nn.ReLU()
self.reducemean_2 = P.ReduceMean(keep_dims=True)
self.sub_3 = P.Sub()
self.sub_4 = P.Sub()
self.pow_5 = P.Pow()
self.pow_5_input_weight = 2.0
self.reducemean_6 = P.ReduceMean(keep_dims=True)
self.add_7 = P.Add()
self.add_7_bias = 9.999999960041972e-13
self.sqrt_8 = P.Sqrt()
self.div_9 = P.Div()
self.mul_10 = P.Mul()
self.mul_10_w = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
self.add_11 = P.Add()
self.add_11_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
self.dense_12 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True)
def construct(self, x):
"""construct function"""
opt_dense_0 = self.dense_0(x)
opt_relu_1 = self.relu_1(opt_dense_0)
opt_reducemean_2 = self.reducemean_2(opt_relu_1, -1)
opt_sub_3 = self.sub_3(opt_relu_1, opt_reducemean_2)
opt_sub_4 = self.sub_4(opt_relu_1, opt_reducemean_2)
opt_pow_5 = self.pow_5(opt_sub_3, self.pow_5_input_weight)
opt_reducemean_6 = self.reducemean_6(opt_pow_5, -1)
opt_add_7 = self.add_7(opt_reducemean_6, self.add_7_bias)
opt_sqrt_8 = self.sqrt_8(opt_add_7)
opt_div_9 = self.div_9(opt_sub_4, opt_sqrt_8)
opt_mul_10 = self.mul_10(self.mul_10_w, opt_div_9)
opt_add_11 = self.add_11(opt_mul_10, self.add_11_bias)
opt_dense_12 = self.dense_12(opt_add_11)
return opt_dense_12

@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reranker Model"""
import mindspore.nn as nn
from mindspore import load_checkpoint, load_param_into_net
from src.rerank_albert_xxlarge import Rerank_Albert
from src.rerank_downstream import Rerank_Downstream
class Reranker(nn.Cell):
"""Reranker model"""
def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
"""init function"""
super(Reranker, self).__init__(auto_prefix=False)
self.encoder = Rerank_Albert(batch_size)
param_dict = load_checkpoint(encoder_ck_file)
not_load_params_1 = load_param_into_net(self.encoder, param_dict)
print(f"not loaded albert: {not_load_params_1}")
self.no_answer_mlp = Rerank_Downstream()
param_dict = load_checkpoint(downstream_ck_file)
not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict)
print(f"not loaded downstream: {not_load_params_2}")
def construct(self, input_ids, attn_mask, token_type_ids):
"""construct function"""
state = self.encoder(input_ids, attn_mask, token_type_ids)
state = state[:, 0, :]
no_answer = self.no_answer_mlp(state)
return no_answer

@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""execute reranker"""
import json
import random
from collections import defaultdict
from time import time
from tqdm import tqdm
import numpy as np
from mindspore import Tensor, ops
from mindspore import dtype as mstype
from src.rerank_and_reader_data_generator import DataGenerator
from src.reranker import Reranker
def rerank(args):
"""rerank function"""
rerank_feature_file = args.rerank_feature_file
rerank_result_file = args.rerank_result_file
encoder_ck_file = args.rerank_encoder_ck_file
downstream_ck_file = args.rerank_downstream_ck_file
seed = args.seed
seq_len = args.seq_len
batch_size = args.rerank_batch_size
random.seed(seed)
np.random.seed(seed)
t1 = time()
generator = DataGenerator(feature_file_path=rerank_feature_file,
example_file_path=None,
batch_size=batch_size, seq_len=seq_len,
task_type="reranker")
gather_dict = defaultdict(lambda: defaultdict(list))
reranker = Reranker(batch_size=batch_size,
encoder_ck_file=encoder_ck_file,
downstream_ck_file=downstream_ck_file)
print("start re-ranking ...")
for _, batch in tqdm(enumerate(generator)):
input_ids = Tensor(batch["context_idxs"], mstype.int32)
attn_mask = Tensor(batch["context_mask"], mstype.int32)
token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
no_answer = reranker(input_ids, attn_mask, token_type_ids)
no_answer_prob = ops.Softmax()(no_answer).asnumpy()
no_answer_prob = no_answer_prob[:, 0]
for i in range(len(batch['ids'])):
qas_id = batch['ids'][i]
gather_dict[qas_id][no_answer_prob[i]].append(batch['unique_ids'][i])
gather_dict[qas_id][no_answer_prob[i]].append(batch['path'][i])
rerank_result = {}
for qas_id in tqdm(gather_dict, desc="get top1 path from re-rank result"):
all_paths = gather_dict[qas_id]
all_paths = sorted(all_paths.items(), key=lambda item: item[0])
assert qas_id not in rerank_result
rerank_result[qas_id] = all_paths[0][1]
with open(rerank_result_file, 'w') as f:
json.dump(rerank_result, f)
t2 = time()
print(f"re-rank cost time: {t2-t1} s")
Loading…
Cancel
Save