@@ -19,11 +19,12 @@ Bert finetune and evaluation script.
 import os
 import argparse
+import time
 from src.bert_for_finetune import BertFinetuneCell, BertNER
 from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
 from src.dataset import create_ner_dataset
 from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
-from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1
+from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore import log as logger
@@ -79,17 +80,22 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
     netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
     model = Model(netwithgrads)
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
+    train_begin = time.time()
     model.train(epoch_num, dataset, callbacks=callbacks)
+    train_end = time.time()
+    print("latency: {:.6f} s".format(train_end - train_begin))
 
 def eval_result_print(assessment_method="accuracy", callback=None):
     """ print eval result """
     if assessment_method == "accuracy":
         print("acc_num {} , total_num {} , accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                    callback.acc_num / callback.total_num))
-    elif assessment_method in ("f1", "spanf1"):
+    elif assessment_method == "bf1":
         print("Precision {:.6f}".format(callback.TP / (callback.TP + callback.FP)))
         print("Recall {:.6f}".format(callback.TP / (callback.TP + callback.FN)))
         print("F1 {:.6f}".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f}".format(callback.eval()[0]))
     elif assessment_method == "mcc":
         print("MCC {:.6f}".format(callback.cal()))
     elif assessment_method == "spearman_correlation":
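As a quick sanity check of the "bf1" branch above: it computes standard precision/recall/F1 from the callback's TP/FP/FN counts. A minimal sketch with made-up counts standing in for the F1 callback's accumulated state:

    TP, FP, FN = 80, 10, 20                 # hypothetical counts, for illustration only
    precision = TP / (TP + FP)              # 80 / 90  = 0.888889
    recall = TP / (TP + FN)                 # 80 / 100 = 0.800000
    f1 = 2 * TP / (2 * TP + FP + FN)        # 160 / 190 = 0.842105, the harmonic mean of precision and recall
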
@@ -116,10 +122,10 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met
     else:
         if assessment_method == "accuracy":
             callback = Accuracy()
-        elif assessment_method == "f1":
+        elif assessment_method == "bf1":
             callback = F1((use_crf.lower() == "true"), num_class)
-        elif assessment_method == "spanf1":
-            callback = SpanF1((use_crf.lower() == "true"), tag_to_index)
+        elif assessment_method == "mf1":
+            callback = F1((use_crf.lower() == "true"), num_labels=num_class, mode="MultiLabel")
         elif assessment_method == "mcc":
             callback = MCC()
         elif assessment_method == "spearman_correlation":
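The dispatch in this hunk boils down to a simple name-to-callback mapping, with F1 in MultiLabel mode taking over the role of the removed SpanF1. A sketch under the assumption that the constructors match their uses in this diff (make_callback is a hypothetical helper, not part of the script, and the no-arg Spearman_Correlation call is assumed):

    def make_callback(assessment_method, use_crf, num_class):
        # Hypothetical helper mirroring the branches above.
        if assessment_method == "accuracy":
            return Accuracy()
        if assessment_method == "bf1":
            return F1(use_crf, num_class)                    # binary-style F1
        if assessment_method == "mf1":
            # MultiLabel mode of F1 replaces the removed SpanF1 callback.
            return F1(use_crf, num_labels=num_class, mode="MultiLabel")
        if assessment_method == "mcc":
            return MCC()
        if assessment_method == "spearman_correlation":
            return Spearman_Correlation()                    # assumed no-arg constructor
        raise ValueError("unsupported assessment_method: " + assessment_method)
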
@@ -145,8 +151,8 @@ def parse_args():
     parser = argparse.ArgumentParser(description="run ner")
     parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
                         help="Device type, default is Ascend")
-    parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"],
-                        help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1")
+    parser.add_argument("--assessment_method", type=str, default="BF1", choices=["BF1", "clue_benchmark", "MF1"],
+                        help="assessment_method include: [BF1, clue_benchmark, MF1], default is BF1")
     parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
                         help="Enable train, default is false")
     parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
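With the renamed choices, a run that previously passed --assessment_method SpanF1 would now pass MF1. For example (assuming the script is invoked as run_ner.py; data-path flags elided):

    python run_ner.py --device_target Ascend --do_train true --do_eval true --assessment_method MF1 ...
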
@@ -231,6 +237,12 @@ def run_ner():
                                 assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
                                 schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
                                 do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
+        print("==============================================================")
+        print("processor_name: {}".format(args_opt.device_target))
+        print("test_name: BERT Finetune Training")
+        print("model_name: {}".format("BERT+MLP+CRF" if args_opt.use_crf.lower() == "true" else "BERT + MLP"))
+        print("batch_size: {}".format(args_opt.train_batch_size))
+
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
 
     if args_opt.do_eval.lower() == "true":
@@ -245,7 +257,7 @@ def run_ner():
         ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
                                 schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
-                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
+                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), drop_remainder=False)
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                 args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
                 args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size)
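The drop_remainder=False added to the eval dataset keeps the final partial batch, so every evaluation sample is scored instead of being silently dropped when the set size is not a multiple of the batch size. A sketch of the effect with hypothetical sizes:

    num_samples, batch_size = 1002, 16                      # hypothetical eval-set size and batch size
    full_batches, tail = divmod(num_samples, batch_size)    # 62 full batches, 10 samples left over
    # drop_remainder=True  -> 62 batches, 992 samples scored (10 dropped)
    # drop_remainder=False -> 63 batches, all 1002 samples scored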