From c14c70747558a408e8331577ec0489b8ee2f7064 Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Wed, 20 Jan 2021 14:12:52 +0800
Subject: [PATCH] Support CPU tinybert and ner task

---
 model_zoo/official/nlp/tinybert/README.md              |  10 +-
 model_zoo/official/nlp/tinybert/export.py              |  18 ++-
 .../official/nlp/tinybert/run_general_distill.py       |  35 ++---
 .../official/nlp/tinybert/run_task_distill.py          | 127 +++++++++++------
 .../official/nlp/tinybert/scripts/run_standalone_td.sh |   7 +-
 .../official/nlp/tinybert/src/assessment_method.py     |  53 ++++++--
 .../official/nlp/tinybert/src/tinybert_for_gd_td.py    |  22 ++-
 .../official/nlp/tinybert/src/tinybert_model.py        |  37 +++++
 8 files changed, 220 insertions(+), 89 deletions(-)

diff --git a/model_zoo/official/nlp/tinybert/README.md b/model_zoo/official/nlp/tinybert/README.md
index 829c359c63..9aa4fbf047 100644
--- a/model_zoo/official/nlp/tinybert/README.md
+++ b/model_zoo/official/nlp/tinybert/README.md
@@ -1,4 +1,4 @@
-# Contents 
+# Contents
 
 - [Contents](#contents)
 - [TinyBERT Description](#tinybert-description)
@@ -197,8 +197,9 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
                            [--load_gd_ckpt_path LOAD_GD_CKPT_PATH]
                            [--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
                            [--train_data_dir TRAIN_DATA_DIR]
-                           [--eval_data_dir EVAL_DATA_DIR]
+                           [--eval_data_dir EVAL_DATA_DIR] [--task_type TASK_TYPE]
                            [--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
+                           [--assessment_method ASSESSMENT_METHOD]
 
 options:
     --device_target            device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
@@ -217,7 +218,9 @@ options:
     --load_td1_ckpt_path       path to load checkpoint files which produced by task distill phase 1: PATH, default is ""
     --train_data_dir           path to train dataset directory: PATH, default is ""
     --eval_data_dir            path to eval dataset directory: PATH, default is ""
-    --task_name                classification task: "SST-2" | "QNLI" | "MNLI", default is ""
+    --task_type                task type: "classification" | "ner", default is "classification"
+    --task_name                classification or NER task: "SST-2" | "QNLI" | "MNLI" | "TNEWS" | "CLUENER", default is ""
+    --assessment_method        assessment method to do evaluation: "accuracy" | "bf1" | "mf1", default is "accuracy"
     --schema_dir               path to schema.json file, PATH, default is ""
     --dataset_type             the dataset type which can be tfrecord/mindrecord, default is tfrecord
```
@@ -249,6 +252,7 @@ Parameters for optimizer:
 
 Parameters for bert network:
     seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
+                                    Usually, we use 21128 for CN vocabs and 30522 for EN vocabs according to the original paper.
     hidden_size                     size of bert encoder layers: N
     num_hidden_layers               number of hidden layers: N
     num_attention_heads             number of attention heads: N, default is 12
diff --git a/model_zoo/official/nlp/tinybert/export.py b/model_zoo/official/nlp/tinybert/export.py
index 50aeeca37b..61847600d6 100644
--- a/model_zoo/official/nlp/tinybert/export.py
+++ b/model_zoo/official/nlp/tinybert/export.py
@@ -22,7 +22,7 @@ from mindspore import Tensor, context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
 
 from src.td_config import td_student_net_cfg
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER
 
 parser = argparse.ArgumentParser(description='tinybert task distill')
 parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -31,7 +31,10 @@ parser.add_argument("--file_name", type=str, default="tinybert", help="output fi
 parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
 parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"],
                     help="device target (default: Ascend)")
-parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
+parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                    help="The type of the task to train.")
+parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
+                    help="The name of the task to train.")
 args = parser.parse_args()
 
 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@@ -43,7 +46,9 @@ DEFAULT_SEQ_LENGTH = 128
 DEFAULT_BS = 32
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 43, "seq_length": 128}}
 
 
 class Task:
     """
@@ -68,8 +73,13 @@ if __name__ == '__main__':
     task = Task(args.task_name)
     td_student_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.batch_size = DEFAULT_BS
+    if args.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    else:
+        raise ValueError(f"Unsupported task type: {args.task_type}")
 
-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     param_dict = load_checkpoint(args.ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
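The export path keys both the output head and its input shape off the task table above. Note that `CLUENER` must use the same `num_labels` here as in `run_task_distill.py` (43), or the exported dense layer will not match the trained checkpoint. The `Task` helper's body is elided from this hunk; a minimal, hypothetical sketch of the lookup it performs (the name `lookup` and the fallback behaviour are assumptions for illustration):

```python
# Hypothetical sketch of the task lookup; the real Task helper body is elided from this patch.
DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
               "QNLI": {"num_labels": 2, "seq_length": 128},
               "MNLI": {"num_labels": 3, "seq_length": 128},
               "TNEWS": {"num_labels": 15, "seq_length": 128},
               "CLUENER": {"num_labels": 43, "seq_length": 128}}

def lookup(task_name):
    """Return (num_labels, seq_length), falling back to the defaults for unknown tasks."""
    params = task_params.get(task_name, {})
    return (params.get("num_labels", DEFAULT_NUM_LABELS),
            params.get("seq_length", DEFAULT_SEQ_LENGTH))

assert lookup("CLUENER") == (43, 128)   # NER head must match the trained checkpoint
assert lookup("") == (2, 128)           # empty task name -> defaults
```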
diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py
index 2bfb3f417f..19a2f82b36 100644
--- a/model_zoo/official/nlp/tinybert/run_general_distill.py
+++ b/model_zoo/official/nlp/tinybert/run_general_distill.py
@@ -33,14 +33,10 @@ from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
 
-
-def run_general_distill():
-    """
-    run general distill
-    """
+def get_argument():
+    """Tinybert general distill argument parser."""
     parser = argparse.ArgumentParser(description='tinybert general distill')
-    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                         help="Run distribute, default is false.")
@@ -61,20 +57,21 @@ def run_general_distill():
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args_opt = parser.parse_args()
+    return args_opt
+
+
+def run_general_distill():
+    """
+    run general distill
+    """
+    args_opt = get_argument()
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
     if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-
-    context.set_context(reserve_class_name_in_scope=False)
+        context.set_context(device_id=args_opt.device_id)
 
     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                  datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
-
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
             D.init()
@@ -104,6 +101,14 @@ def run_general_distill():
         # and the loss scale is not necessary
         enable_loss_scale = False
 
+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only supports float32 temporarily, running with float32.')
+        bert_teacher_net_cfg.dtype = mstype.float32
+        bert_teacher_net_cfg.compute_type = mstype.float32
+        bert_student_net_cfg.dtype = mstype.float32
+        bert_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,
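The CPU branch above simply forces every dtype knob to float32 before the network is built and disables dynamic loss scaling; `run_task_distill.py` below applies the same fallback. A minimal standalone sketch of the pattern, with plain Python objects standing in for the MindSpore config instances (`NetCfg` and `force_fp32` are illustrative names, not part of the patch):

```python
from dataclasses import dataclass

@dataclass
class NetCfg:                      # stand-in for bert_*_net_cfg / td_*_net_cfg
    dtype: str = "float16"
    compute_type: str = "float16"

def force_fp32(*cfgs):
    """CPU path: no float16 kernels available, so run everything in float32."""
    for cfg in cfgs:
        cfg.dtype = "float32"
        cfg.compute_type = "float32"
    return False                   # loss scaling is unnecessary in pure fp32

teacher, student = NetCfg(), NetCfg()
enable_loss_scale = force_fp32(teacher, student)
assert teacher.compute_type == "float32" and not enable_loss_scale
```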
diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py
index 430f27ac3b..e655a6cdff 100644
--- a/model_zoo/official/nlp/tinybert/run_task_distill.py
+++ b/model_zoo/official/nlp/tinybert/run_task_distill.py
@@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
 from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
-from src.assessment_method import Accuracy
+from src.assessment_method import Accuracy, F1
 from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
 from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER
 
 _cur_dir = os.getcwd()
 td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
@@ -46,7 +46,7 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
                         help="Do train task, default is true.")
@@ -69,21 +69,46 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="",
                         help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="",
                         help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
-    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+    parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                        help="The type of the task to train.")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
                         help="The name of the task to train.")
+    parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
+                        help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args = parser.parse_args()
+    if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
+        raise ValueError("At least one of do_train and do_eval must be \"true\", please confirm your config")
+    if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
+        raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
+    if args.task_name in ["CLUENER"] and args.task_type != "ner":
+        raise ValueError(f"{args.task_name} is an NER dataset, please set --task_type=ner")
+    if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
+            (td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
+        logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "
+                       "EN vocabs according to the original paper.")
+    if args.task_name in ["TNEWS", "CLUENER"] and \
+            (td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
+        logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for "
+                       "EN vocabs according to the original paper.")
     return args
 
 
 args_opt = parse_args()
-
+if args_opt.dataset_type == "tfrecord":
+    dataset_type = DataType.TFRECORD
+elif args_opt.dataset_type == "mindrecord":
+    dataset_type = DataType.MINDRECORD
+else:
+    raise ValueError(f"dataset_type should be tfrecord or mindrecord, but got: {args_opt.dataset_type}")
 DEFAULT_NUM_LABELS = 2
 DEFAULT_SEQ_LENGTH = 128
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 43, "seq_length": 128}}
 
 
 class Task:
@@ -112,29 +137,15 @@ def run_predistill():
     run predistill
     """
     cfg = phase1_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    context.set_context(reserve_class_name_in_scope=False)
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = args_opt.load_gd_ckpt_path
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=True)
     rank = 0
     device_num = 1
-
-    if args_opt.dataset_type == "tfrecord":
-        dataset_type = DataType.TFRECORD
-    elif args_opt.dataset_type == "mindrecord":
-        dataset_type = DataType.MINDRECORD
-    else:
-        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
                                       args_opt.train_data_dir, args_opt.schema_dir,
@@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
         raise ValueError("Student ckpt file should not be None")
     cfg = phase2_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = ckpt_file
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=False)
 
     rank = 0
     device_num = 1
     train_dataset = create_tinybert_dataset('td', cfg.batch_size,
                                             device_num, rank, args_opt.do_shuffle,
-                                            args_opt.train_data_dir, args_opt.schema_dir)
+                                            args_opt.train_data_dir, args_opt.schema_dir,
+                                            data_type=dataset_type)
 
     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
@@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):
 
     eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
-                                           args_opt.eval_data_dir, args_opt.schema_dir)
+                                           args_opt.eval_data_dir, args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
 
     if args_opt.do_eval.lower() == "true":
@@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
                     dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                     sink_size=args_opt.data_sink_steps)
 
+def eval_result_print(assessment_method="accuracy", callback=None):
+    """Print the eval result for the given assessment method."""
+    if assessment_method == "accuracy":
+        print("============== acc is {}".format(callback.acc_num / callback.total_num))
+    elif assessment_method == "bf1":
+        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
+        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
+        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f} ".format(callback.eval()))
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
+
 def do_eval_standalone():
     """
     do eval standalone
     """
     ckpt_file = args_opt.load_td1_ckpt_path
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+    if args_opt.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args_opt.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+        raise ValueError(f"Unsupported task type: {args_opt.task_type}")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -289,11 +307,18 @@ def do_eval_standalone():
     eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
-                                           schema_dir=args_opt.schema_dir)
+                                           schema_dir=args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('eval dataset size: ', eval_dataset.get_dataset_size())
     print('eval dataset batch size: ', eval_dataset.get_batch_size())
-
-    callback = Accuracy()
+    if args_opt.assessment_method == "accuracy":
+        callback = Accuracy()
+    elif args_opt.assessment_method == "bf1":
+        callback = F1(num_labels=task.num_labels)
+    elif args_opt.assessment_method == "mf1":
+        callback = F1(num_labels=task.num_labels, mode="MultiLabel")
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator(num_epochs=1):
         input_data = []
@@ -302,16 +327,16 @@ def do_eval_standalone():
         input_ids, input_mask, token_type_id, label_ids = input_data
         logits = eval_model(input_ids, token_type_id, input_mask)
         callback.update(logits, label_ids)
-    acc = callback.acc_num / callback.total_num
-    print("======================================")
-    print("============== acc is {}".format(acc))
-    print("======================================")
+    print("==============================================================")
+    eval_result_print(args_opt.assessment_method, callback)
+    print("==============================================================")
 
 
 if __name__ == '__main__':
-    if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
-        raise ValueError("do_train or do_eval must have one be true, please confirm your config")
-
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
+    if args_opt.device_target == "Ascend":
+        context.set_context(device_id=args_opt.device_id)
     enable_loss_scale = True
     if args_opt.device_target == "GPU":
         if td_student_net_cfg.compute_type != mstype.float32:
@@ -321,6 +346,14 @@ if __name__ == '__main__':
             # and the loss scale is not necessary
             enable_loss_scale = False
 
+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only supports float32 temporarily, running with float32.')
+        td_teacher_net_cfg.dtype = mstype.float32
+        td_teacher_net_cfg.compute_type = mstype.float32
+        td_student_net_cfg.dtype = mstype.float32
+        td_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     td_teacher_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.seq_length = task.seq_length
diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh
index 2a6dbe7885..900ed9473d 100644
--- a/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh
+++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh
@@ -32,7 +32,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --do_eval="true" \
     --td_phase1_epoch_size=10 \
     --td_phase2_epoch_size=3 \
-    --task_name="" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=100 \
@@ -44,5 +43,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --train_data_dir="" \
     --eval_data_dir="" \
     --schema_dir="" \
-    --dataset_type="tfrecord" > log.txt 2>&1 &
-
+    --dataset_type="tfrecord" \
+    --task_type="classification" \
+    --task_name="" \
+    --assessment_method="accuracy" > log.txt 2>&1 &
diff --git a/model_zoo/official/nlp/tinybert/src/assessment_method.py b/model_zoo/official/nlp/tinybert/src/assessment_method.py
index 748666e3ce..6ee2f64a8c 100644
--- a/model_zoo/official/nlp/tinybert/src/assessment_method.py
+++ b/model_zoo/official/nlp/tinybert/src/assessment_method.py
@@ -32,23 +32,56 @@ class Accuracy():
         self.total_num += len(labels)
 
 
 class F1():
-    """F1"""
-    def __init__(self):
+    """
+    Calculate the F1 score.
+    """
+    def __init__(self, num_labels=2, mode="Binary"):
         self.TP = 0
         self.FP = 0
         self.FN = 0
+        self.num_labels = num_labels
+        self.P = 0
+        self.AP = 0
+        self.mode = mode
+        if self.mode.lower() not in ("binary", "multilabel"):
+            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")
 
     def update(self, logits, labels):
-        """Update F1 score"""
+        """
+        Update the F1 statistics from a batch of logits and labels.
+        """
         labels = labels.asnumpy()
         labels = np.reshape(labels, -1)
         logits = logits.asnumpy()
         logit_id = np.argmax(logits, axis=-1)
         logit_id = np.reshape(logit_id, -1)
-        pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7])
-        pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7])
-        self.TP += np.sum(pos_eva & pos_label)
-        self.FP += np.sum(pos_eva & (~pos_label))
-        self.FN += np.sum((~pos_eva) & pos_label)
-        print("-----------------precision is ", self.TP / (self.TP + self.FP))
-        print("-----------------recall is ", self.TP / (self.TP + self.FN))
+
+        if self.mode.lower() == "binary":
+            # label id 0 is the negative class; every other id counts as positive
+            pos_eva = np.isin(logit_id, np.arange(1, self.num_labels))
+            pos_label = np.isin(labels, np.arange(1, self.num_labels))
+            self.TP += np.sum(pos_eva & pos_label)
+            self.FP += np.sum(pos_eva & (~pos_label))
+            self.FN += np.sum((~pos_eva) & pos_label)
+        else:
+            # accumulate per-class one-hot counts for micro-F1
+            target = np.zeros((len(labels), self.num_labels), dtype=np.int32)
+            pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int32)
+            for i, label in enumerate(labels):
+                target[i][label] = 1
+            for i, label in enumerate(logit_id):
+                pred[i][label] = 1
+            positives = pred.sum(axis=0)
+            actual_positives = target.sum(axis=0)
+            true_positives = (target * pred).sum(axis=0)
+            self.TP += true_positives
+            self.P += positives
+            self.AP += actual_positives
+
+    def eval(self):
+        """Compute the F1 score from the accumulated statistics."""
+        if self.mode.lower() == "binary":
+            f1 = 2 * self.TP / (2 * self.TP + self.FP + self.FN)
+        else:
+            tp = np.sum(self.TP)
+            p = np.sum(self.P)
+            ap = np.sum(self.AP)
+            f1 = 2 * tp / (ap + p)
+        return f1
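A quick NumPy cross-check of the two F1 modes above: binary mode pools every non-zero label into one positive class (matching `np.arange(1, num_labels)`), while MultiLabel mode accumulates per-class counts and reduces to micro-F1 via `2*TP / (P + AP)`. This sketch assumes label id 0 is the negative/"O" class, as the class does; it also exercises the corrected binary formula `2*TP / (2*TP + FP + FN)` (the version in the original hunk dropped the leading factor of 2):

```python
import numpy as np

labels = np.array([0, 1, 2, 2, 0, 1])
preds = np.array([0, 1, 2, 0, 0, 2])

# Binary mode: any label id >= 1 counts as positive.
pos_pred, pos_label = preds > 0, labels > 0
tp = np.sum(pos_pred & pos_label)          # 3
fp = np.sum(pos_pred & ~pos_label)         # 0
fn = np.sum(~pos_pred & pos_label)         # 1
print("binary F1:", 2 * tp / (2 * tp + fp + fn))   # 0.857...

# MultiLabel mode: per-class counts, reduced to micro-F1 = 2*TP / (P + AP).
num_labels = 3
target = np.eye(num_labels, dtype=np.int32)[labels]
pred = np.eye(num_labels, dtype=np.int32)[preds]
tp_c = (target * pred).sum(axis=0)         # per-class true positives
p_c = pred.sum(axis=0)                     # per-class predicted positives
ap_c = target.sum(axis=0)                  # per-class actual positives
print("micro F1:", 2 * tp_c.sum() / (p_c.sum() + ap_c.sum()))  # 0.666...
```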
diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py
index 6a90a271f3..3b1468fd41 100644
--- a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py
+++ b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py
@@ -28,7 +28,7 @@ from mindspore.communication.management import get_group_size
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
+from .tinybert_model import BertModel, TinyBertModel, BertModelCLS, BertModelNER
 
 GRADIENT_CLIP_TYPE = 1
@@ -362,8 +362,18 @@ class BertNetworkWithLoss_td(nn.Cell):
                  temperature=1.0, dropout_prob=0.1):
         super(BertNetworkWithLoss_td, self).__init__()
         # load teacher model
-        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
-                                    use_one_hot_embeddings, "teacher")
+        if task_type == "classification":
+            self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        elif task_type == "ner":
+            self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        else:
+            raise ValueError(f"Unsupported task type: {task_type}")
         param_dict = load_checkpoint(teacher_ckpt)
         new_param_dict = {}
         for key, value in param_dict.items():
@@ -377,8 +387,6 @@ class BertNetworkWithLoss_td(nn.Cell):
         for param in params:
             param.requires_grad = False
         # load student model
-        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
-                                 use_one_hot_embeddings, "student")
         param_dict = load_checkpoint(student_ckpt)
         if is_predistill:
             new_param_dict = {}
@@ -401,7 +409,7 @@ class BertNetworkWithLoss_td(nn.Cell):
         self.is_predistill = is_predistill
         self.is_att_fit = is_att_fit
         self.is_rep_fit = is_rep_fit
-        self.task_type = task_type
+        self.use_soft_cross_entropy = task_type in ["classification", "ner"]
         self.temperature = temperature
         self.loss_mse = nn.MSELoss()
         self.select = P.Select()
@@ -456,7 +464,7 @@ class BertNetworkWithLoss_td(nn.Cell):
                 rep_loss += self.loss_mse(student_rep, teacher_rep)
             total_loss += rep_loss
         else:
-            if self.task_type == "classification":
+            if self.use_soft_cross_entropy:
                 cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
             else:
                 cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])
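The `use_soft_cross_entropy` flag above routes both classification and NER through the teacher-student soft cross-entropy at temperature T. The cell's `soft_cross_entropy` is defined elsewhere in this file; the following NumPy stand-in only illustrates the math (cross-entropy between the teacher's tempered soft labels and the student's tempered log-probabilities), not the exact MindSpore implementation:

```python
import numpy as np

def soft_cross_entropy(student_logits, teacher_logits, temperature=1.0):
    """Cross-entropy between teacher soft labels and student predictions."""
    s = student_logits / temperature
    t = teacher_logits / temperature
    # numerically stable log-softmax of the student logits
    s_log_prob = s - s.max(-1, keepdims=True)
    s_log_prob = s_log_prob - np.log(np.exp(s_log_prob).sum(-1, keepdims=True))
    # softmax of the teacher logits -> soft labels
    t_prob = np.exp(t - t.max(-1, keepdims=True))
    t_prob = t_prob / t_prob.sum(-1, keepdims=True)
    return -(t_prob * s_log_prob).sum(-1).mean()

student = np.array([[2.0, 0.5, -1.0]])
teacher = np.array([[3.0, 0.2, -2.0]])
print(soft_cross_entropy(student, teacher, temperature=1.0))
```

A higher temperature flattens the teacher distribution, transferring more of the "dark knowledge" in the non-argmax classes; at T=1 this reduces to ordinary distillation against the teacher's softmax.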
diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py
index 0c0e2cbabf..47891934b0 100644
--- a/model_zoo/official/nlp/tinybert/src/tinybert_model.py
+++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py
@@ -926,3 +926,40 @@ class BertModelCLS(nn.Cell):
         if self._phase == 'train' or self.phase_type == "teacher":
             return seq_output, att_output, logits, log_probs
         return log_probs
+
+
+class BertModelNER(nn.Cell):
+    """
+    This class is responsible for the sequence labeling (NER) task, e.g. CLUENER.
+    The returned output is the log_softmax of the final logits; since log_softmax
+    is monotonic in softmax, it can be used directly for prediction.
+    """
+    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
+                 use_one_hot_embeddings=False, phase_type="student"):
+        super(BertModelNER, self).__init__()
+        if not is_training:
+            config.hidden_dropout_prob = 0.0
+            config.hidden_probs_dropout_prob = 0.0
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.cast = P.Cast()
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.log_softmax = P.LogSoftmax(axis=-1)
+        self.dtype = config.dtype
+        self.num_labels = num_labels
+        self.phase_type = phase_type
+        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
+                                has_bias=True).to_float(config.compute_type)
+        # keep_prob semantics: apply dropout (not ReLU) to the encoder output
+        self.dropout = nn.Dropout(1 - dropout_prob)
+        self.reshape = P.Reshape()
+        self.shape = (-1, config.hidden_size)
+        self.origin_shape = (-1, config.seq_length, self.num_labels)
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        """Return the log_softmax of the token-level logits."""
+        sequence_output, _, _, encoder_outputs, attention_outputs = \
+            self.bert(input_ids, token_type_id, input_mask)
+        seq = self.dropout(sequence_output)
+        seq = self.reshape(seq, self.shape)
+        logits = self.dense_1(seq)
+        logits = self.cast(logits, self.dtype)
+        return_value = self.log_softmax(logits)
+        if self._phase == 'train' or self.phase_type == "teacher":
+            return encoder_outputs, attention_outputs, logits, return_value
+        return return_value
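For reference, a shape walkthrough of the NER head above: the encoder output is flattened to `(-1, hidden_size)`, projected to per-token label logits, and passed through log_softmax. A pure NumPy stand-in (not the MindSpore cell; batch 2, seq_length 8, hidden 384, and 43 labels are illustrative values only):

```python
import numpy as np

batch, seq_len, hidden, num_labels = 2, 8, 384, 43
sequence_output = np.random.randn(batch, seq_len, hidden).astype(np.float32)
w = np.random.randn(hidden, num_labels).astype(np.float32)  # stand-in for dense_1 weight
b = np.zeros(num_labels, dtype=np.float32)

seq = sequence_output.reshape(-1, hidden)       # mirrors self.shape = (-1, hidden_size)
logits = seq @ w + b                            # token-level logits
z = logits - logits.max(-1, keepdims=True)      # numerically stable log_softmax
log_probs = z - np.log(np.exp(z).sum(-1, keepdims=True))

assert log_probs.shape == (batch * seq_len, num_labels)
# self.origin_shape would restore the per-sequence view:
assert log_probs.reshape(batch, seq_len, num_labels).shape == (2, 8, 43)
```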