!9258 add squad for bert
From: @yoonlee666 Reviewed-by: @c_34,@guoqi1024 Signed-off-by: @c_34pull/9258/MERGE
commit
989744c61a
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,97 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""evaluation script for SQuAD v1.1"""
|
||||
|
||||
from collections import Counter
|
||||
import string
|
||||
import re
|
||||
import json
|
||||
import sys
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
def remove_articles(text):
|
||||
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
"""calculate f1 score"""
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
def evaluate(dataset, predictions):
|
||||
"""do evaluation"""
|
||||
f1 = exact_match = total = 0
|
||||
for article in dataset:
|
||||
for paragraph in article['paragraphs']:
|
||||
for qa in paragraph['qas']:
|
||||
total += 1
|
||||
if qa['id'] not in predictions:
|
||||
message = 'Unanswered question ' + qa['id'] + \
|
||||
' will receive score 0.'
|
||||
print(message, file=sys.stderr)
|
||||
continue
|
||||
ground_truths = list(map(lambda x: x['text'], qa['answers']))
|
||||
if not ground_truths:
|
||||
continue
|
||||
prediction = predictions[qa['id']]
|
||||
exact_match += metric_max_over_ground_truths(
|
||||
exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(
|
||||
f1_score, prediction, ground_truths)
|
||||
|
||||
exact_match = 100.0 * exact_match / total
|
||||
f1 = 100.0 * f1 / total
|
||||
return {'exact_match': exact_match, 'f1': f1}
|
||||
|
||||
|
||||
def SQuad_postprocess(dataset_file, all_predictions, output_metrics="output.json"):
|
||||
with open(dataset_file) as ds:
|
||||
dataset_json = json.load(ds)
|
||||
dataset = dataset_json['data']
|
||||
re_json = evaluate(dataset, all_predictions)
|
||||
print(json.dumps(re_json))
|
||||
with open(output_metrics, 'w') as wr:
|
||||
wr.write(json.dumps(re_json))
|
Loading…
Reference in new issue