# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Bert evaluation script: evaluates a fine-tuned BERT checkpoint on an NER task
(F1 score) or a classification task (accuracy), driven by evaluation_config.
"""

import os
import numpy as np
from evaluation_config import cfg, bert_net_cfg
from utils import BertNER, BertCLS
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.common.tensor import Tensor
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from CRF import postprocess
from cluener_evaluation import submit
from finetune_config import tag_to_index


class Accuracy():
    '''
    calculate accuracy
    '''
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        # take the argmax over the class dimension as the predicted label id
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)
        print("=========================accuracy is ", self.acc_num / self.total_num)


class F1():
    '''
    calculate F1 score
    '''
    def __init__(self):
        self.TP = 0
        self.FP = 0
        self.FN = 0

    def update(self, logits, labels):
        '''
        update F1 score
        '''
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        if cfg.use_crf:
            # decode the best tag sequence from the CRF backpointers
            backpointers, best_tag_id = logits
            best_path = postprocess(backpointers, best_tag_id)
            logit_id = []
            for ele in best_path:
                logit_id.extend(ele)
        else:
            logits = logits.asnumpy()
            logit_id = np.argmax(logits, axis=-1)
            logit_id = np.reshape(logit_id, -1)
        # label id 0 is the negative class; every other id counts as a positive
        pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
        pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
        self.TP += np.sum(pos_eva & pos_label)
        self.FP += np.sum(pos_eva & (~pos_label))
        self.FN += np.sum((~pos_eva) & pos_label)


def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    '''
    get dataset
    '''
    _ = distribute_file
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file,
                            columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    ds = ds.map(input_columns="label_ids", operations=type_cast_op)
    ds = ds.repeat(repeat_count)

    # apply shuffle operation
    buffer_size = 960
    ds = ds.shuffle(buffer_size=buffer_size)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


def bert_predict(Evaluation):
    '''
    prediction function
    '''
    # DEVICE_ID must be set in the environment to select the Ascend device
    devid = int(os.getenv('DEVICE_ID'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
    dataset = get_dataset(bert_net_cfg.batch_size, 1)
    if cfg.use_crf:
        net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index),
                                         use_crf=True, tag_to_index=tag_to_index, dropout_prob=0.0)
    else:
        net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(cfg.finetune_ckpt)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    return model, dataset


def test_eval():
    '''
    evaluation function
    '''
    task_type = BertNER if cfg.task == "NER" else BertCLS
    model, dataset = bert_predict(task_type)

    if cfg.clue_benchmark:
        submit(model, cfg.data_file, bert_net_cfg.seq_length)
    else:
        callback = F1() if cfg.task == "NER" else Accuracy()
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(Tensor(data[i]))
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        if cfg.task == "NER":
            print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
            print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
            # F1 = 2 * TP / (2 * TP + FP + FN)
            print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
        else:
            print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                      callback.acc_num / callback.total_num))
        print("==============================================================")


if __name__ == "__main__":
    num_labels = cfg.num_labels
    test_eval()