commit
eca821227d
@ -0,0 +1,164 @@
# Contents

- [Thinking Path Re-Ranker](#thinking-path-re-ranker)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
- [Description of random situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [Thinking Path Re-Ranker](#contents)

Thinking Path Re-Ranker (TPRR) was proposed in 2021 by Huawei Poisson Lab & Parallel Distributed Computing Lab. By combining
retriever, reranker and reader modules, TPRR performs strongly on open-domain multi-hop question answering, and it ranked first on
the official HotpotQA leaderboard at the time of writing. This is an example of evaluating TPRR on the HotpotQA dataset in MindSpore.
It is also the first open-source version of TPRR.

# [Model Architecture](#contents)

TPRR contains three main modules. The first is the retriever, which iteratively generates the document sequence for each hop. The second
is the reranker, which selects the best path from the candidate paths generated by the retriever. The last is the reader, which extracts answer spans.
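
The retriever evaluation script in this commit (`retriever_eval.py`) scores a two-hop path by multiplying the softmax of the second-hop scores with a per-path weight derived from the first hop. The sketch below reproduces that combination in plain Python with made-up scores. The helper names `softmax` and `score_paths` and the `parent` mapping are hypothetical, and it assumes the weight is the softmax probability of the first-hop document each path extends, which cannot be confirmed because `process_data.py` is not shown.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def score_paths(onehop_scores, twohop_scores, parent, topk=2):
    """Combine first-hop and second-hop scores into one score per path.

    parent[j] is the rank (0..topk-1) of the first-hop document that path j extends.
    """
    # Keep the top-k first-hop documents, best first, and normalise their scores.
    order = sorted(range(len(onehop_scores)), key=lambda i: onehop_scores[i], reverse=True)
    top_idx = order[:topk]
    onehop_prob = softmax([onehop_scores[i] for i in top_idx])
    # Normalise the second-hop scores over all candidate paths.
    twohop_prob = softmax(twohop_scores)
    # Path score = second-hop probability * probability of its first-hop document.
    path_scores = [p * onehop_prob[parent[j]] for j, p in enumerate(twohop_prob)]
    return top_idx, path_scores


onehop_scores = [2.0, 0.5, 1.0]       # made-up scores for 3 candidate documents
twohop_scores = [0.1, 1.2, 0.3, 0.9]  # made-up scores for 4 candidate paths
parent = [0, 0, 1, 1]                 # paths 0-1 extend the best document, 2-3 the runner-up

top_idx, path_scores = score_paths(onehop_scores, twohop_scores, parent)
best_path = max(range(len(path_scores)), key=path_scores.__getitem__)
print(top_idx, best_path)
```

In the actual script the same steps run on MindSpore tensors (`P.TopK`, `P.Softmax`, `P.Mul`), and the reranker and reader stages are not part of this evaluation.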

# [Dataset](#contents)

The retriever dataset consists of three parts:

- Wikipedia data: the 2017 English Wikipedia dump with bidirectional hyperlinks.
- dev data: the HotpotQA full-wiki dev set, with 7398 question-answer pairs.
- dev tf-idf data: the candidates for each question in the dev data, taken from the top-500 paragraphs retrieved by TF-IDF from the 5M paragraphs of Wikipedia.
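
The TF-IDF variant used to build the candidate file is not described in this README. As an illustration only, the sketch below ranks a toy corpus with a generic smoothed TF-IDF score and keeps the top k paragraphs; the function name `tfidf_topk`, the whitespace tokenisation and the IDF formula are all assumptions rather than the authors' pipeline.

```python
import math
from collections import Counter


def tfidf_topk(query, paragraphs, k):
    """Return the indices of the k paragraphs with the highest TF-IDF score for the query."""
    docs = [Counter(p.lower().split()) for p in paragraphs]
    n_docs = len(docs)
    # Document frequency: the number of paragraphs that contain each term.
    df = Counter(term for doc in docs for term in doc)
    scores = []
    for doc in docs:
        score = 0.0
        for term in query.lower().split():
            if term in doc:
                idf = math.log((1 + n_docs) / (1 + df[term])) + 1.0  # smoothed IDF
                score += doc[term] * idf
        scores.append(score)
    ranked = sorted(range(n_docs), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


paragraphs = [
    "Paris is the capital of France",
    "Berlin is the capital of Germany",
    "The Eiffel Tower is in Paris",
]
top2 = tfidf_topk("capital of France", paragraphs, k=2)
print(top2)
```

For the dataset described above, k would be 500 and the corpus would be the 5M Wikipedia paragraphs.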

# [Features](#contents)

## [Mixed Precision](#contents)

To utilize the strong computing power of the Ascend chip and accelerate evaluation, mixed-precision evaluation is used. MindSpore
can handle FP32 inputs together with FP16 operators. In the TPRR example, the matmul computations of the model run in FP16.
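
To give a feel for what running matmul in FP16 means numerically, here is a conceptual sketch in plain Python. It does not use any MindSpore API: operands are rounded to half precision with the standard `struct` format `"e"`, and the helper names `to_fp16` and `matmul_fp16` are hypothetical. Accumulation is done in full precision here, which may differ from what the hardware does.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def matmul_fp16(a, b):
    """Multiply two matrices (lists of rows) after rounding the operands to FP16."""
    a16 = [[to_fp16(v) for v in row] for row in a]
    b16 = [[to_fp16(v) for v in row] for row in b]
    cols = list(zip(*b16))
    # Products are accumulated in full precision in this sketch.
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a16]


a = [[0.1, 0.2]]
b = [[1.0], [2.0]]
result = matmul_fp16(a, b)[0][0]
print(result)  # close to, but not exactly, 0.5
```

The small difference from 0.5 comes from rounding 0.1 and 0.2 to FP16. In `src/onehop.py` and `src/twohop.py` the network output is cast back to `mstype.float32` before it is returned.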

# [Environment Requirements](#contents)

- Hardware (Ascend)
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore from the official website and generating the dataset correctly, you can start evaluation as follows.

- running on Ascend

```shell
# run evaluation example with HotpotQA dev dataset
sh run_eval_ascend.sh
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```shell
.
└─tprr
  ├─README.md
  ├─scripts
  | ├─run_eval_ascend.sh            # Launch evaluation on Ascend
  |
  ├─src
  | ├─config.py                     # Evaluation configurations
  | ├─onehop.py                     # Onehop model
  | ├─onehop_bert.py                # Onehop bert model
  | ├─process_data.py               # Data preprocessing
  | ├─twohop.py                     # Twohop model
  | ├─twohop_bert.py                # Twohop bert model
  | └─utils.py                      # Utils for evaluation
  |
  └─retriever_eval.py               # Evaluation net for retriever
```

## [Script Parameters](#contents)

Parameters for evaluation can be set in config.py.

- config for TPRR retriever dataset

```python
"q_len": 64,           # Max query length
"d_len": 192,          # Max doc length
"s_len": 448,          # Max sequence length
"in_len": 768,         # Input dim
"out_len": 1,          # Output dim
"num_docs": 500,       # Num of docs
"topk": 8,             # Top k
"onehop_num": 8        # Num of onehop doc as twohop neighbor
```

See config.py for more configuration options.
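
Because `src/config.py` builds its settings with `argparse`, each parameter is also a command-line flag with the default listed above. The sketch below is a cut-down stand-in for `ThinkRetrieverConfig` (the name `build_parser` is hypothetical, and only three of the arguments are reproduced) showing how an override would be parsed.

```python
import argparse


def build_parser():
    """A reduced stand-in for ThinkRetrieverConfig with three of its arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--q_len", type=int, default=64, help="max query len")
    parser.add_argument("--topk", type=int, default=8, help="top num")
    parser.add_argument("--device_id", type=int, default=0, help="device id")
    return parser


# Parsed as if the flags had been passed to: python retriever_eval.py --topk 4 --device_id 1
config = build_parser().parse_args(["--topk", "4", "--device_id", "1"])
print(config.q_len, config.topk, config.device_id)
```

Note that `run_eval_ascend.sh`, as shown in this commit, calls `python retriever_eval.py` without arguments, so passing flags would mean editing that line or running the Python script directly.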

## [Evaluation Process](#contents)

### Evaluation

- Evaluation on Ascend

```shell
sh run_eval_ascend.sh
```

The evaluation results are stored under the scripts path, in a folder whose name begins with "eval". The log contains results
like the following.

```text
###step###: 0
val: 0
count: 1
true count: 0
PEM: 0.0

...
###step###: 7396
val: 6796
count: 7397
true count: 6924
PEM: 0.9187508449371367
true top8 PEM: 0.9815135759676488
evaluation time (h): 20.155506462653477
```
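
The two ratios in the log follow directly from the counters printed above them. Reading `retriever_eval.py`, `count` is the number of questions evaluated so far, `val` counts questions whose gold path is among the top 8 ranked paths, and `true count` counts questions whose gold path appears anywhere among the candidate paths. A small sketch of the arithmetic (the function name `pem_metrics` is hypothetical):

```python
def pem_metrics(val, count, true_count):
    """Return (PEM, true top8 PEM) from the counters printed in the log."""
    pem = val / count if count else 0.0
    true_top8_pem = val / true_count if true_count else 0.0
    return pem, true_top8_pem


# Counters taken from the final step of the log above.
pem, true_top8_pem = pem_metrics(val=6796, count=7397, true_count=6924)
print(round(pem, 4), round(true_top8_pem, 4))
```

With the final counters this reproduces the logged values of about 0.9188 and 0.9815.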

# [Model Description](#contents)

## [Performance](#contents)

### Inference Performance

| Parameter         | TPRR Ascend                 |
| ----------------- | --------------------------- |
| Model Version     | TPRR                        |
| Resource          | Ascend 910                  |
| Uploaded Date     | 03/12/2021 (month/day/year) |
| MindSpore Version | 1.2.0                       |
| Dataset           | HotpotQA                    |
| Batch_size        | 1                           |
| Output            | inference path              |
| PEM               | 0.9188                      |

# [Description of random situation](#contents)

No random situation for evaluation.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,180 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Retriever Evaluation.

"""

import time
import json

import numpy as np
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore import load_checkpoint, load_param_into_net

from src.onehop import OneHopBert
from src.twohop import TwoHopBert
from src.process_data import DataGen
from src.onehop_bert import ModelOneHop
from src.twohop_bert import ModelTwoHop
from src.config import ThinkRetrieverConfig
from src.utils import read_query, split_queries, get_new_title, get_raw_title, save_json


def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
    """evaluation output"""
    y_pred_raw = out_2.asnumpy()
    last_out_raw = last_out.asnumpy()
    path = []
    y_pred = []
    last_out_list = []
    topk_titles = []
    # Walk the candidates from highest to lowest score and drop duplicate paths
    # (a path counts as a duplicate if both of its titles already appear in a kept path).
    index_list_raw = np.argsort(y_pred_raw)
    for index_r in index_list_raw[::-1]:
        tag = 1
        for raw_path in path:
            if path_raw[index_r][0] in raw_path and path_raw[index_r][1] in raw_path:
                tag = 0
                break
        if tag:
            path.append(path_raw[index_r])
            y_pred.append(y_pred_raw[index_r])
            last_out_list.append(last_out_raw[index_r])
    # index_list is sorted in ascending score order.
    index_list = np.argsort(y_pred)
    # true_count: the gold path appears anywhere among the deduplicated candidates.
    for path_index in index_list:
        if gold_path[0] in path[path_index] and gold_path[1] in path[path_index]:
            true_count += 1
            break
    # Collect the 8 best paths, best first.
    for path_index in index_list[-8:][::-1]:
        topk_titles.append(list(path[path_index]))
    # val: the gold path appears among the 8 best paths.
    for path_index in index_list[-8:]:
        if gold_path[0] in path[path_index] and gold_path[1] in path[path_index]:
            val += 1
            break
    return val, true_count, topk_titles


def evaluation():
    """evaluation"""
    print('********************** loading corpus ********************** ')
    s_lc = time.time()
    data_generator = DataGen(config)
    queries = read_query(config)
    print("loading corpus time (h):", (time.time() - s_lc) / 3600)
    print('********************** loading model ********************** ')
    s_lm = time.time()

    model_onehop_bert = ModelOneHop()
    param_dict = load_checkpoint(config.onehop_bert_path)
    load_param_into_net(model_onehop_bert, param_dict)
    model_twohop_bert = ModelTwoHop()
    param_dict2 = load_checkpoint(config.twohop_bert_path)
    load_param_into_net(model_twohop_bert, param_dict2)
    onehop = OneHopBert(config, model_onehop_bert)
    twohop = TwoHopBert(config, model_twohop_bert)

    print("loading model time (h):", (time.time() - s_lm) / 3600)
    print('********************** evaluation ********************** ')
    s_tr = time.time()

    # Use a context manager so the dev file is closed after loading.
    with open(config.dev_path, 'rb') as f_dev:
        dev_data = json.load(f_dev)
    q_gold = {}
    q_2id = {}
    for onedata in dev_data:
        if onedata["question"] not in q_gold:
            q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
            q_2id[onedata["question"]] = onedata['_id']
    val, true_count, count, step = 0, 0, 0, 0
    # Note: the last batch is dropped here, as in the original script.
    batch_queries = split_queries(config, queries)[:-1]
    output_path = []
    for _, batch in enumerate(batch_queries):
        print("###step###: ", step)
        query = batch[0]
        temp_dict = {}
        temp_dict['q_id'] = q_2id[query]
        temp_dict['question'] = query
        gold_path = q_gold[query]
        input_ids_1, token_type_ids_1, input_mask_1 = data_generator.convert_onehop_to_features(batch)
        # First hop: score the candidates in chunks of split_chunk and concatenate the results.
        start = 0
        TOTAL = len(input_ids_1)
        split_chunk = 8
        while start < TOTAL:
            end = min(start + split_chunk - 1, TOTAL - 1)
            chunk_len = end - start + 1
            input_ids_1_ = input_ids_1[start:start + chunk_len]
            input_ids_1_ = Tensor(input_ids_1_, mstype.int32)
            token_type_ids_1_ = token_type_ids_1[start:start + chunk_len]
            token_type_ids_1_ = Tensor(token_type_ids_1_, mstype.int32)
            input_mask_1_ = input_mask_1[start:start + chunk_len]
            input_mask_1_ = Tensor(input_mask_1_, mstype.int32)
            cls_out = onehop(input_ids_1_, token_type_ids_1_, input_mask_1_)
            if start == 0:
                out = cls_out
            else:
                out = P.Concat(0)((out, cls_out))
            start = end + 1
        out = P.Squeeze(1)(out)
        onehop_prob, onehop_index = P.TopK(sorted=True)(out, config.topk)
        onehop_prob = P.Softmax()(onehop_prob)
        sample, path_raw, last_out = data_generator.get_samples(query, onehop_index, onehop_prob)
        input_ids_2, token_type_ids_2, input_mask_2 = data_generator.convert_twohop_to_features(sample)
        # Second hop: same chunked scoring over the two-hop samples.
        start_2 = 0
        TOTAL_2 = len(input_ids_2)
        split_chunk = 8
        while start_2 < TOTAL_2:
            end_2 = min(start_2 + split_chunk - 1, TOTAL_2 - 1)
            chunk_len = end_2 - start_2 + 1
            input_ids_2_ = input_ids_2[start_2:start_2 + chunk_len]
            input_ids_2_ = Tensor(input_ids_2_, mstype.int32)
            token_type_ids_2_ = token_type_ids_2[start_2:start_2 + chunk_len]
            token_type_ids_2_ = Tensor(token_type_ids_2_, mstype.int32)
            input_mask_2_ = input_mask_2[start_2:start_2 + chunk_len]
            input_mask_2_ = Tensor(input_mask_2_, mstype.int32)
            cls_out = twohop(input_ids_2_, token_type_ids_2_, input_mask_2_)
            if start_2 == 0:
                out_2 = cls_out
            else:
                out_2 = P.Concat(0)((out_2, cls_out))
            start_2 = end_2 + 1
        # Path score = softmax of the second-hop scores, weighted by last_out.
        out_2 = P.Softmax()(out_2)
        last_out = Tensor(last_out, mstype.float32)
        out_2 = P.Mul()(out_2, last_out)
        val, true_count, topk_titles = eval_output(out_2, last_out, path_raw, gold_path, val, true_count)
        temp_dict['topk_titles'] = topk_titles
        output_path.append(temp_dict)
        count += 1
        print("val:", val)
        print("count:", count)
        print("true count:", true_count)
        if count:
            print("PEM:", val / count)
        if true_count:
            print("true top8 PEM:", val / true_count)
        step += 1
    save_json(output_path, config.save_path, config.save_name)
    print("evaluation time (h):", (time.time() - s_tr) / 3600)


if __name__ == "__main__":
    config = ThinkRetrieverConfig()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target='Ascend',
                        device_id=config.device_id,
                        save_graphs=False)
    evaluation()
@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# eval script

ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval

cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation"

python retriever_eval.py > log.txt 2>&1 &

cd ..
@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Retriever Config.

"""
import argparse


def ThinkRetrieverConfig():
    """retriever config"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--q_len", type=int, default=64, help="max query len")
    parser.add_argument("--d_len", type=int, default=192, help="max doc len")
    parser.add_argument("--s_len", type=int, default=448, help="max seq len")
    parser.add_argument("--in_len", type=int, default=768, help="in len")
    parser.add_argument("--out_len", type=int, default=1, help="out len")
    parser.add_argument("--num_docs", type=int, default=500, help="docs num")
    parser.add_argument("--topk", type=int, default=8, help="top num")
    parser.add_argument("--onehop_num", type=int, default=8, help="onehop num")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--device_id", type=int, default=0, help="device id")
    parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
    parser.add_argument("--save_path", type=str, default='./', help='path of output')
    parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path")
    parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path")
    parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json',
                        help="dev path")
    parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path")
    parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path")
    parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path")
    parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path")
    parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path")
    return parser.parse_args()
@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
One Hop Model.

"""

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore import load_checkpoint, load_param_into_net


class Model(nn.Cell):
    """mlp model"""
    def __init__(self):
        super(Model, self).__init__()
        self.tanh_0 = nn.Tanh()
        self.dense_1 = nn.Dense(in_channels=768, out_channels=1, has_bias=True)

    def construct(self, x):
        """construct function"""
        opt_tanh_0 = self.tanh_0(x)
        opt_dense_1 = self.dense_1(opt_tanh_0)
        return opt_dense_1


class OneHopBert(nn.Cell):
    """onehop model"""
    def __init__(self, config, network):
        super(OneHopBert, self).__init__(auto_prefix=False)
        self.network = network
        self.mlp = Model()
        param_dict = load_checkpoint(config.onehop_mlp_path)
        load_param_into_net(self.mlp, param_dict)
        self.cast = P.Cast()

    def construct(self,
                  input_ids,
                  token_type_id,
                  input_mask):
        """construct function"""
        out = self.network(input_ids, token_type_id, input_mask)
        out = self.mlp(out)
        out = self.cast(out, mstype.float32)
        return out
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Two Hop Model.

"""

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore import load_checkpoint, load_param_into_net


class Model(nn.Cell):
    """mlp model"""
    def __init__(self):
        super(Model, self).__init__()
        self.tanh_0 = nn.Tanh()
        self.dense_1 = nn.Dense(in_channels=768, out_channels=1, has_bias=True)

    def construct(self, x):
        """construct function"""
        opt_tanh_0 = self.tanh_0(x)
        opt_dense_1 = self.dense_1(opt_tanh_0)
        return opt_dense_1


class TwoHopBert(nn.Cell):
    """two hop model"""
    def __init__(self, config, network):
        super(TwoHopBert, self).__init__(auto_prefix=False)
        self.network = network
        self.mlp = Model()
        param_dict = load_checkpoint(config.twohop_mlp_path)
        load_param_into_net(self.mlp, param_dict)
        self.reshape = P.Reshape()
        self.cast = P.Cast()

    def construct(self,
                  input_ids,
                  token_type_id,
                  input_mask):
        """construct function"""
        out = self.network(input_ids, token_type_id, input_mask)
        out = self.mlp(out)
        out = self.cast(out, mstype.float32)
        out = self.reshape(out, (-1,))
        return out
File diff suppressed because it is too large
@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Retriever Utils.

"""

import json
import unicodedata
import pickle as pkl


def normalize(text):
    """normalize text"""
    text = unicodedata.normalize('NFD', text)
    return text[0].capitalize() + text[1:]


def read_query(config):
    """get query data"""
    with open(config.dev_data_path, 'rb') as f:
        temp_dic = pkl.load(f, encoding='gbk')
        queries = []
        for item in temp_dic:
            queries.append(temp_dic[item]["query"])
        return queries


def split_queries(config, queries):
    batch_size = config.batch_size
    batch_queries = [queries[i:i + batch_size] for i in range(0, len(queries), batch_size)]
    return batch_queries


def save_json(obj, path, name):
    with open(path + name, "w") as f:
        return json.dump(obj, f)


def get_new_title(title):
    """get new title"""
    if title[-2:] == "_0":
        return normalize(title[:-2]) + "_0"
    return normalize(title) + "_0"


def get_raw_title(title):
    """get raw title"""
    if title[-2:] == "_0":
        return title[:-2]
    return title