!6616 add tokenization and score file

Merge pull request !6616 from yoonlee666/token
pull/6616/MERGE
mindspore-ci-bot committed via Gitee, 4 years ago
commit b309850036

run_ner.py

@@ -18,12 +18,11 @@ Bert finetune and evaluation script.
 '''
 import os
-import json
 import argparse
 from src.bert_for_finetune import BertFinetuneCell, BertNER
 from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
 from src.dataset import create_ner_dataset
-from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
+from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
 from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
 import mindspore.common.dtype as mstype
 from mindspore import context
@@ -99,7 +98,7 @@ def eval_result_print(assessment_method="accuracy", callback=None):
         raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
 def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="",
-            load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None):
+            load_checkpoint_path="", vocab_file="", label_file="", tag_to_index=None):
     """ do eval """
     if load_checkpoint_path == "":
         raise ValueError("Finetune model missed, evaluation task must load finetune model!")
@@ -114,7 +113,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth
     if assessment_method == "clue_benchmark":
         from src.cluener_evaluation import submit
-        submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file)
+        submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf,
+               label_file=label_file, tag_to_index=tag_to_index)
     else:
         if assessment_method == "accuracy":
             callback = Accuracy()
@@ -161,7 +161,7 @@ def parse_args():
     parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
                         help="Enable eval data shuffle, default is false")
     parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
-    parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
+    parser.add_argument("--label_file_path", type=str, default="", help="label file path, used in clue benchmark")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
     parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
@@ -180,10 +180,10 @@ def parse_args():
         raise ValueError("'eval_data_file_path' must be set when do evaluation task")
     if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "":
         raise ValueError("'vocab_file_path' must be set to do clue benchmark")
-    if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "":
-        raise ValueError("'label2id_file_path' must be set to use crf")
-    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "":
-        raise ValueError("'label2id_file_path' must be set to do clue benchmark")
+    if args_opt.use_crf.lower() == "true" and args_opt.label_file_path == "":
+        raise ValueError("'label_file_path' must be set to use crf")
+    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label_file_path == "":
+        raise ValueError("'label_file_path' must be set to do clue benchmark")
     return args_opt
@@ -205,11 +205,12 @@ def run_ner():
         bert_net_cfg.compute_type = mstype.float32
     else:
         raise Exception("Target error, GPU or Ascend is supported.")
-    tag_to_index = None
+    label_list = []
+    with open(args_opt.label_file_path) as f:
+        for label in f:
+            label_list.append(label.strip())
+    tag_to_index = convert_labels_to_index(label_list)
     if args_opt.use_crf.lower() == "true":
-        with open(args_opt.label2id_file_path) as json_file:
-            tag_to_index = json.load(json_file)
         max_val = max(tag_to_index.values())
         tag_to_index["<START>"] = max_val + 1
         tag_to_index["<STOP>"] = max_val + 2
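Note: run_ner.py now reads --label_file_path as plain text, one raw label per line (each line is stripped into label_list), instead of loading a label2id JSON file. An assumed example of such a file (the label names are illustrative, not taken from this PR):

    address
    name

convert_labels_to_index expands these entries into the tag-to-index mapping used below, and the <START>/<STOP> tags are appended only when use_crf is "true".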
@@ -240,7 +241,7 @@ def run_ner():
                                 schema_file_path=args_opt.schema_file_path,
                                 do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
-                load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)
+                load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label_file_path, tag_to_index)

 if __name__ == "__main__":
     run_ner()

scripts/run_ner.sh

@@ -38,7 +38,7 @@ python ${PROJECT_DIR}/../run_ner.py \
     --train_data_shuffle="true" \
     --eval_data_shuffle="false" \
     --vocab_file_path="" \
-    --label2id_file_path="" \
+    --label_file_path="" \
     --save_finetune_checkpoint_path="" \
     --load_pretrain_checkpoint_path="" \
     --load_finetune_checkpoint_path="" \

src/cluener_evaluation.py

@@ -23,9 +23,9 @@ from src import tokenization
 from src.sample_process import label_generation, process_one_example_p
 from src.CRF import postprocess
 from src.finetune_eval_config import bert_net_cfg
+from src.score import get_result

-def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
+def process(model=None, text="", tokenizer_=None, use_crf="", tag_to_index=None, vocab=""):
     """
     process text.
     """
@@ -34,7 +34,7 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
     res = []
     ids = []
     for i in data:
-        feature = process_one_example_p(tokenizer_, i, max_seq_len=bert_net_cfg.seq_length)
+        feature = process_one_example_p(tokenizer_, vocab, i, max_seq_len=bert_net_cfg.seq_length)
         features.append(feature)
         input_ids, input_mask, token_type_id = feature
         input_ids = Tensor(np.array(input_ids), mstype.int32)
@@ -52,10 +52,10 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
         ids = logits.asnumpy()
         ids = np.argmax(ids, axis=-1)
         ids = list(ids)
-    res = label_generation(text=text, probs=ids, label2id_file=label2id_file)
+    res = label_generation(text=text, probs=ids, tag_to_index=tag_to_index)
     return res

-def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
+def submit(model=None, path="", vocab_file="", use_crf="", label_file="", tag_to_index=None):
     """
     submit task
     """
@@ -66,8 +66,11 @@ def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
             continue
         oneline = json.loads(line.strip())
         res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_,
-                      use_crf=use_crf, label2id_file=label2id_file)
-        print("text", oneline["text"])
-        print("res:", res)
+                      use_crf=use_crf, tag_to_index=tag_to_index, vocab=vocab_file)
         data.append(json.dumps({"label": res}, ensure_ascii=False))
     open("ner_predict.json", "w").write("\n".join(data))
+    labels = []
+    with open(label_file) as f:
+        for label in f:
+            labels.append(label.strip())
+    get_result(labels, "ner_predict.json", path)
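Note: with this change, clue_benchmark evaluation no longer stops after writing ner_predict.json. submit() re-reads the label names from label_file and calls get_result(labels, "ner_predict.json", path), so the original evaluation input at path serves as the gold file for the per-label and average F1 scores printed by src/score.py.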

src/sample_process.py

@@ -16,9 +16,9 @@
 """process txt"""
 import re
-import json
+from src.tokenization import convert_tokens_to_ids

-def process_one_example_p(tokenizer, text, max_seq_len=128):
+def process_one_example_p(tokenizer, vocab, text, max_seq_len=128):
     """process one testline"""
     textlist = list(text)
     tokens = []
@@ -37,7 +37,7 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
         segment_ids.append(0)
     ntokens.append("[SEP]")
     segment_ids.append(0)
-    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
+    input_ids = convert_tokens_to_ids(vocab, ntokens)
     input_mask = [1] * len(input_ids)
     while len(input_ids) < max_seq_len:
         input_ids.append(0)
@@ -52,12 +52,12 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
     feature = (input_ids, input_mask, segment_ids)
     return feature

-def label_generation(text="", probs=None, label2id_file=""):
+def label_generation(text="", probs=None, tag_to_index=None):
     """generate label"""
     data = [text]
     probs = [probs]
     result = []
-    label2id = json.loads(open(label2id_file).read())
+    label2id = tag_to_index
     id2label = [k for k, v in label2id.items()]
     for index, prob in enumerate(probs):
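Note: label_generation now takes the in-memory tag_to_index mapping instead of re-reading a label2id JSON file. The id2label inversion still works because convert_labels_to_index assigns consecutive indices in insertion order; a small illustration with made-up labels:

    tag_to_index = {"O": 0, "S_name": 1, "B_name": 2, "M_name": 3, "E_name": 4}
    id2label = [k for k, v in tag_to_index.items()]
    # id2label[2] == "B_name": an argmax index maps straight back to its tag name.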

src/score.py (new file)

@@ -0,0 +1,79 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Calculate average F1 score among labels.
+"""
+
+import json
+
+
+def get_f1_score_for_each_label(pre_lines, gold_lines, label):
+    """
+    Get F1 score for each label.
+    Args:
+        pre_lines: listed label info from pre_file.
+        gold_lines: listed label info from gold_file.
+        label:
+    Returns:
+        F1 score for this label.
+    """
+    TP = 0
+    FP = 0
+    FN = 0
+    index = 0
+    while index < len(pre_lines):
+        pre_line = pre_lines[index].get(label, {})
+        gold_line = gold_lines[index].get(label, {})
+        for sample in pre_line:
+            if sample in gold_line:
+                TP += 1
+            else:
+                FP += 1
+        for sample in gold_line:
+            if sample not in pre_line:
+                FN += 1
+        index += 1
+    f1 = 2 * TP / (2 * TP + FP + FN)
+    return f1
+
+
+def get_f1_score(labels, pre_file, gold_file):
+    """
+    Get F1 scores for each label.
+    Args:
+        labels: list of labels.
+        pre_file: prediction file.
+        gold_file: ground truth file.
+    Returns:
+        average F1 score on all labels.
+    """
+    pre_lines = [json.loads(line.strip())['label'] for line in open(pre_file) if line.strip()]
+    gold_lines = [json.loads(line.strip())['label'] for line in open(gold_file) if line.strip()]
+    if len(pre_lines) != len(gold_lines):
+        raise ValueError("pre file and gold file have different line count.")
+    f1_sum = 0
+    for label in labels:
+        f1 = get_f1_score_for_each_label(pre_lines, gold_lines, label)
+        print('label: %s, F1: %.6f' % (label, f1))
+        f1_sum += f1
+    return f1_sum/len(labels)
+
+
+def get_result(labels, pre_file, gold_file):
+    avg = get_f1_score(labels, pre_file=pre_file, gold_file=gold_file)
+    print("avg F1: %.6f" % avg)
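For context, a minimal sketch of how the new get_result is meant to be driven. The nested per-label structure below is only illustrative (get_f1_score only requires that each line's "label" value map a label name to a container supporting membership tests), and it assumes src/score.py from this commit is importable:

    import json
    from src.score import get_result

    gold = [{"label": {"name": {"张三": [[0, 1]]}, "address": {"北京": [[3, 4]]}}}]
    pred = [{"label": {"name": {"张三": [[0, 1]]}, "address": {"上海": [[3, 4]]}}}]

    with open("gold.json", "w") as f:
        f.write("\n".join(json.dumps(line, ensure_ascii=False) for line in gold))
    with open("pred.json", "w") as f:
        f.write("\n".join(json.dumps(line, ensure_ascii=False) for line in pred))

    # "name" matches exactly (F1 = 1.0); "address" has one false positive and one
    # false negative (F1 = 0.0), so the printed average F1 is 0.5.
    get_result(["name", "address"], "pred.json", "gold.json")

Note that get_f1_score_for_each_label divides by 2 * TP + FP + FN, so every label passed in should occur at least once in the predictions or the gold annotations.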

File diff suppressed because it is too large (likely the new tokenization file referenced in the PR title).

src/utils.py

@@ -19,6 +19,7 @@ Functional Cells used in Bert finetune and evaluation.
 import os
 import math
+import collections
 import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger
@@ -213,3 +214,19 @@ class BertLearningRate(LearningRateSchedule):
         else:
             lr = decay_lr
         return lr
+
+
+def convert_labels_to_index(label_list):
+    """
+    Convert label_list to indices for NER task.
+    """
+    label2id = collections.OrderedDict()
+    label2id["O"] = 0
+    prefix = ["S_", "B_", "M_", "E_"]
+    index = 0
+    for label in label_list:
+        for pre in prefix:
+            index += 1
+            sub_label = pre + label
+            label2id[sub_label] = index
+    return label2id
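For reference, a standalone sketch (not part of the diff) showing the mapping convert_labels_to_index produces for a two-entry label list; the label names are illustrative:

    import collections

    def build_label2id(label_list):
        # Same expansion as convert_labels_to_index above: "O" is fixed at index 0 and
        # every raw label gets S_/B_/M_/E_ variants with consecutive indices.
        label2id = collections.OrderedDict()
        label2id["O"] = 0
        index = 0
        for label in label_list:
            for pre in ["S_", "B_", "M_", "E_"]:
                index += 1
                label2id[pre + label] = index
        return label2id

    print(build_label2id(["address", "name"]))
    # OrderedDict([('O', 0), ('S_address', 1), ('B_address', 2), ('M_address', 3), ('E_address', 4),
    #              ('S_name', 5), ('B_name', 6), ('M_name', 7), ('E_name', 8)])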
