Support CPU tinybert and ner task

pull/13903/head
zhaoting 4 years ago
parent 5d0490909d
commit c14c707475

@ -1,4 +1,4 @@
# Contents # Contents
- [Contents](#contents) - [Contents](#contents)
- [TinyBERT Description](#tinybert-description) - [TinyBERT Description](#tinybert-description)
@ -197,8 +197,9 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
[--load_gd_ckpt_path LOAD_GD_CKPT_PATH] [--load_gd_ckpt_path LOAD_GD_CKPT_PATH]
[--load_td1_ckpt_path LOAD_TD1_CKPT_PATH] [--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
[--train_data_dir TRAIN_DATA_DIR] [--train_data_dir TRAIN_DATA_DIR]
[--eval_data_dir EVAL_DATA_DIR] [--eval_data_dir EVAL_DATA_DIR] [--task_type TASK_TYPE]
[--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE] [--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
[--assessment_method ASSESSMENT_METHOD]
options: options:
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend" --device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
@ -217,7 +218,9 @@ options:
--load_td1_ckpt_path path to load checkpoint files which produced by task distill phase 1: PATH, default is "" --load_td1_ckpt_path path to load checkpoint files which produced by task distill phase 1: PATH, default is ""
--train_data_dir path to train dataset directory: PATH, default is "" --train_data_dir path to train dataset directory: PATH, default is ""
--eval_data_dir path to eval dataset directory: PATH, default is "" --eval_data_dir path to eval dataset directory: PATH, default is ""
--task_name classification task: "SST-2" | "QNLI" | "MNLI", default is "" --task_type task type: "classification" | "ner", default is "classification"
--task_name classification or ner task: "SST-2" | "QNLI" | "MNLI" | "TNEWS" | "CLUENER", default is ""
--assessment_method assessment method to do evaluation: "accuracy" | "bf1" | "mf1", default is "accuracy"
--schema_dir path to schema.json file, PATH, default is "" --schema_dir path to schema.json file, PATH, default is ""
--dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord --dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord
``` ```
@ -249,6 +252,7 @@ Parameters for optimizer:
Parameters for bert network: Parameters for bert network:
seq_length length of input sequence: N, default is 128 seq_length length of input sequence: N, default is 128
vocab_size size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522 vocab_size size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
Usually, we use 21128 for CN vocabs and 30522 for EN vocabs according to the original paper. Default is 30522
hidden_size size of bert encoder layers: N hidden_size size of bert encoder layers: N
num_hidden_layers number of hidden layers: N num_hidden_layers number of hidden layers: N
num_attention_heads number of attention heads: N, default is 12 num_attention_heads number of attention heads: N, default is 12

@ -22,7 +22,7 @@ from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.td_config import td_student_net_cfg from src.td_config import td_student_net_cfg
from src.tinybert_model import BertModelCLS from src.tinybert_model import BertModelCLS, BertModelNER
parser = argparse.ArgumentParser(description='tinybert task distill') parser = argparse.ArgumentParser(description='tinybert task distill')
parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--device_id", type=int, default=0, help="Device id")
@ -31,7 +31,10 @@ parser.add_argument("--file_name", type=str, default="tinybert", help="output fi
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend", parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name') parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
help="The type of the task to train.")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
help="The name of the task to train.")
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@ -43,7 +46,9 @@ DEFAULT_SEQ_LENGTH = 128
DEFAULT_BS = 32 DEFAULT_BS = 32
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64}, task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
"QNLI": {"num_labels": 2, "seq_length": 128}, "QNLI": {"num_labels": 2, "seq_length": 128},
"MNLI": {"num_labels": 3, "seq_length": 128}} "MNLI": {"num_labels": 3, "seq_length": 128},
"TNEWS": {"num_labels": 15, "seq_length": 128},
"CLUENER": {"num_labels": 10, "seq_length": 128}}
class Task: class Task:
""" """
@ -68,8 +73,13 @@ if __name__ == '__main__':
task = Task(args.task_name) task = Task(args.task_name)
td_student_net_cfg.seq_length = task.seq_length td_student_net_cfg.seq_length = task.seq_length
td_student_net_cfg.batch_size = DEFAULT_BS td_student_net_cfg.batch_size = DEFAULT_BS
if args.task_type == "classification":
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
elif args.task_type == "ner":
eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
else:
raise ValueError(f"Not support task type: {args.task_type}")
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
param_dict = load_checkpoint(args.ckpt_file) param_dict = load_checkpoint(args.ckpt_file)
new_param_dict = {} new_param_dict = {}
for key, value in param_dict.items(): for key, value in param_dict.items():

@ -33,14 +33,10 @@ from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
def get_argument():
"""Tinybert general distill argument parser."""
def run_general_distill():
"""
run general distill
"""
parser = argparse.ArgumentParser(description='tinybert general distill') parser = argparse.ArgumentParser(description='tinybert general distill')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)') help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"], parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.") help="Run distribute, default is false.")
@ -61,20 +57,21 @@ def run_general_distill():
parser.add_argument("--dataset_type", type=str, default="tfrecord", parser.add_argument("--dataset_type", type=str, default="tfrecord",
help="dataset type tfrecord/mindrecord, default is tfrecord") help="dataset type tfrecord/mindrecord, default is tfrecord")
args_opt = parser.parse_args() args_opt = parser.parse_args()
return args_opt
def run_general_distill():
"""
run general distill
"""
args_opt = get_argument()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
reserve_class_name_in_scope=False)
if args_opt.device_target == "Ascend": if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
context.set_context(reserve_class_name_in_scope=False)
save_ckpt_dir = os.path.join(args_opt.save_ckpt_path, save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if args_opt.distribute == "true": if args_opt.distribute == "true":
if args_opt.device_target == 'Ascend': if args_opt.device_target == 'Ascend':
D.init() D.init()
@ -104,6 +101,14 @@ def run_general_distill():
# and the loss scale is not necessary # and the loss scale is not necessary
enable_loss_scale = False enable_loss_scale = False
if args_opt.device_target == "CPU":
logger.warning('CPU only support float32 temporarily, run with float32.')
bert_teacher_net_cfg.dtype = mstype.float32
bert_teacher_net_cfg.compute_type = mstype.float32
bert_student_net_cfg.dtype = mstype.float32
bert_student_net_cfg.compute_type = mstype.float32
enable_loss_scale = False
netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg, netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
teacher_ckpt=args_opt.load_teacher_ckpt_path, teacher_ckpt=args_opt.load_teacher_ckpt_path,
student_config=bert_student_net_cfg, student_config=bert_student_net_cfg,

@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
from mindspore import log as logger from mindspore import log as logger
from src.dataset import create_tinybert_dataset, DataType from src.dataset import create_tinybert_dataset, DataType
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.assessment_method import Accuracy from src.assessment_method import Accuracy, F1
from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
from src.tinybert_model import BertModelCLS from src.tinybert_model import BertModelCLS, BertModelNER
_cur_dir = os.getcwd() _cur_dir = os.getcwd()
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt') td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
@ -46,7 +46,7 @@ def parse_args():
parse args parse args
""" """
parser = argparse.ArgumentParser(description='tinybert task distill') parser = argparse.ArgumentParser(description='tinybert task distill')
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'], parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)') help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"], parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
help="Do train task, default is true.") help="Do train task, default is true.")
@ -69,21 +69,46 @@ def parse_args():
parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path") parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path") parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"], parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
help="The type of the task to train.")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
help="The name of the task to train.") help="The name of the task to train.")
parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
parser.add_argument("--dataset_type", type=str, default="tfrecord", parser.add_argument("--dataset_type", type=str, default="tfrecord",
help="dataset type tfrecord/mindrecord, default is tfrecord") help="dataset type tfrecord/mindrecord, default is tfrecord")
args = parser.parse_args() args = parser.parse_args()
if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
raise ValueError("do train or do eval must have one be true, please confirm your config")
if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
if args.task_name in ["CLUENER"] and args.task_type != "ner":
raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
(td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "\
"EN vocabs according to the origin paper.")
if args.task_name in ["TNEWS", "CLUENER"] and \
(td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for " \
"EN vocabs according to the origin paper.")
return args return args
args_opt = parse_args() args_opt = parse_args()
if args_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif args_opt.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
DEFAULT_NUM_LABELS = 2 DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128 DEFAULT_SEQ_LENGTH = 128
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64}, task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
"QNLI": {"num_labels": 2, "seq_length": 128}, "QNLI": {"num_labels": 2, "seq_length": 128},
"MNLI": {"num_labels": 3, "seq_length": 128}} "MNLI": {"num_labels": 3, "seq_length": 128},
"TNEWS": {"num_labels": 15, "seq_length": 128},
"CLUENER": {"num_labels": 43, "seq_length": 128}}
class Task: class Task:
@ -112,29 +137,15 @@ def run_predistill():
run predistill run predistill
""" """
cfg = phase1_cfg cfg = phase1_cfg
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
context.set_context(reserve_class_name_in_scope=False)
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = args_opt.load_gd_ckpt_path load_student_checkpoint_path = args_opt.load_gd_ckpt_path
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
is_training=True, task_type='classification', is_training=True, task_type=args_opt.task_type,
num_labels=task.num_labels, is_predistill=True) num_labels=task.num_labels, is_predistill=True)
rank = 0 rank = 0
device_num = 1 device_num = 1
if args_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif args_opt.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
dataset = create_tinybert_dataset('td', cfg.batch_size, dataset = create_tinybert_dataset('td', cfg.batch_size,
device_num, rank, args_opt.do_shuffle, device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir, args_opt.train_data_dir, args_opt.schema_dir,
@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
raise ValueError("Student ckpt file should not be None") raise ValueError("Student ckpt file should not be None")
cfg = phase2_cfg cfg = phase2_cfg
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = ckpt_file load_student_checkpoint_path = ckpt_file
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
is_training=True, task_type='classification', is_training=True, task_type=args_opt.task_type,
num_labels=task.num_labels, is_predistill=False) num_labels=task.num_labels, is_predistill=False)
rank = 0 rank = 0
device_num = 1 device_num = 1
train_dataset = create_tinybert_dataset('td', cfg.batch_size, train_dataset = create_tinybert_dataset('td', cfg.batch_size,
device_num, rank, args_opt.do_shuffle, device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir) args_opt.train_data_dir, args_opt.schema_dir,
data_type=dataset_type)
dataset_size = train_dataset.get_dataset_size() dataset_size = train_dataset.get_dataset_size()
print('td2 train dataset size: ', dataset_size) print('td2 train dataset size: ', dataset_size)
@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):
eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size, eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
device_num, rank, args_opt.do_shuffle, device_num, rank, args_opt.do_shuffle,
args_opt.eval_data_dir, args_opt.schema_dir) args_opt.eval_data_dir, args_opt.schema_dir,
data_type=dataset_type)
print('td2 eval dataset size: ', eval_dataset.get_dataset_size()) print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
if args_opt.do_eval.lower() == "true": if args_opt.do_eval.lower() == "true":
@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
dataset_sink_mode=(args_opt.enable_data_sink == 'true'), dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
sink_size=args_opt.data_sink_steps) sink_size=args_opt.data_sink_steps)
def eval_result_print(assessment_method="accuracy", callback=None):
    """
    Print the evaluation result accumulated by `callback`.

    Args:
        assessment_method (str): one of "accuracy", "bf1" or "mf1".
        callback: metric accumulator. For "accuracy" it must expose `acc_num`
            and `total_num`; for "bf1" it must expose `TP`, `FP` and `FN`;
            for "mf1" it must expose an `eval()` method returning the score.

    Raises:
        ValueError: if `assessment_method` is not one of the supported names.
    """
    def _ratio(numerator, denominator):
        # An empty evaluation set (or no positive predictions/labels) leaves the
        # denominator at zero; report 0.0 instead of crashing or printing nan.
        return numerator / denominator if denominator else 0.0

    if assessment_method == "accuracy":
        print("============== acc is {}".format(_ratio(callback.acc_num, callback.total_num)))
    elif assessment_method == "bf1":
        tp, fp, fn = callback.TP, callback.FP, callback.FN
        print("Precision {:.6f} ".format(_ratio(tp, tp + fp)))
        print("Recall {:.6f} ".format(_ratio(tp, tp + fn)))
        print("F1 {:.6f} ".format(_ratio(2 * tp, 2 * tp + fp + fn)))
    elif assessment_method == "mf1":
        print("F1 {:.6f} ".format(callback.eval()))
    else:
        # List the names that are actually accepted above.
        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
def do_eval_standalone(): def do_eval_standalone():
""" """
do eval standalone do eval standalone
@ -270,13 +289,12 @@ def do_eval_standalone():
ckpt_file = args_opt.load_td1_ckpt_path ckpt_file = args_opt.load_td1_ckpt_path
if ckpt_file == '': if ckpt_file == '':
raise ValueError("Student ckpt file should not be None") raise ValueError("Student ckpt file should not be None")
if args_opt.device_target == "Ascend": if args_opt.task_type == "classification":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
elif args_opt.device_target == "GPU": elif args_opt.task_type == "ner":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
else: else:
raise Exception("Target error, GPU or Ascend is supported.") raise ValueError(f"Not support the task type {args_opt.task_type}")
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
param_dict = load_checkpoint(ckpt_file) param_dict = load_checkpoint(ckpt_file)
new_param_dict = {} new_param_dict = {}
for key, value in param_dict.items(): for key, value in param_dict.items():
@ -289,11 +307,18 @@ def do_eval_standalone():
eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size, eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
device_num=1, rank=0, do_shuffle="false", device_num=1, rank=0, do_shuffle="false",
data_dir=args_opt.eval_data_dir, data_dir=args_opt.eval_data_dir,
schema_dir=args_opt.schema_dir) schema_dir=args_opt.schema_dir,
data_type=dataset_type)
print('eval dataset size: ', eval_dataset.get_dataset_size()) print('eval dataset size: ', eval_dataset.get_dataset_size())
print('eval dataset batch size: ', eval_dataset.get_batch_size()) print('eval dataset batch size: ', eval_dataset.get_batch_size())
if args_opt.assessment_method == "accuracy":
callback = Accuracy() callback = Accuracy()
elif args_opt.assessment_method == "bf1":
callback = F1(num_labels=task.num_labels)
elif args_opt.assessment_method == "mf1":
callback = F1(num_labels=task.num_labels, mode="MultiLabel")
else:
raise ValueError("Assessment method not supported, support: [accuracy, f1]")
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
for data in eval_dataset.create_dict_iterator(num_epochs=1): for data in eval_dataset.create_dict_iterator(num_epochs=1):
input_data = [] input_data = []
@ -302,16 +327,16 @@ def do_eval_standalone():
input_ids, input_mask, token_type_id, label_ids = input_data input_ids, input_mask, token_type_id, label_ids = input_data
logits = eval_model(input_ids, token_type_id, input_mask) logits = eval_model(input_ids, token_type_id, input_mask)
callback.update(logits, label_ids) callback.update(logits, label_ids)
acc = callback.acc_num / callback.total_num print("==============================================================")
print("======================================") eval_result_print(args_opt.assessment_method, callback)
print("============== acc is {}".format(acc)) print("==============================================================")
print("======================================")
if __name__ == '__main__': if __name__ == '__main__':
if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true": context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
raise ValueError("do_train or do eval must have one be true, please confirm your config") reserve_class_name_in_scope=False)
if args_opt.device_target == "Ascend":
context.set_context(device_id=args_opt.device_id)
enable_loss_scale = True enable_loss_scale = True
if args_opt.device_target == "GPU": if args_opt.device_target == "GPU":
if td_student_net_cfg.compute_type != mstype.float32: if td_student_net_cfg.compute_type != mstype.float32:
@ -321,6 +346,14 @@ if __name__ == '__main__':
# and the loss scale is not necessary # and the loss scale is not necessary
enable_loss_scale = False enable_loss_scale = False
if args_opt.device_target == "CPU":
logger.warning('CPU only support float32 temporarily, run with float32.')
td_teacher_net_cfg.dtype = mstype.float32
td_teacher_net_cfg.compute_type = mstype.float32
td_student_net_cfg.dtype = mstype.float32
td_student_net_cfg.compute_type = mstype.float32
enable_loss_scale = False
td_teacher_net_cfg.seq_length = task.seq_length td_teacher_net_cfg.seq_length = task.seq_length
td_student_net_cfg.seq_length = task.seq_length td_student_net_cfg.seq_length = task.seq_length

@ -32,7 +32,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
--do_eval="true" \ --do_eval="true" \
--td_phase1_epoch_size=10 \ --td_phase1_epoch_size=10 \
--td_phase2_epoch_size=3 \ --td_phase2_epoch_size=3 \
--task_name="" \
--do_shuffle="true" \ --do_shuffle="true" \
--enable_data_sink="true" \ --enable_data_sink="true" \
--data_sink_steps=100 \ --data_sink_steps=100 \
@ -44,5 +43,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
--train_data_dir="" \ --train_data_dir="" \
--eval_data_dir="" \ --eval_data_dir="" \
--schema_dir="" \ --schema_dir="" \
--dataset_type="tfrecord" > log.txt 2>&1 & --dataset_type="tfrecord" \
--task_type="classification" \
--task_name="" \
--assessment_method="accuracy" > log.txt 2>&1 &

@ -32,23 +32,56 @@ class Accuracy():
self.total_num += len(labels) self.total_num += len(labels)
class F1():
    """
    Accumulate F1 statistics over evaluation batches.

    Args:
        num_labels (int): number of label ids. In binary mode the ids
            1..num_labels-1 are treated as positive and id 0 as negative.
        mode (str): "Binary" or "MultiLabel" (case-insensitive).

    Raises:
        ValueError: if `mode` is neither "Binary" nor "MultiLabel".
    """
    def __init__(self, num_labels=2, mode="Binary"):
        self.TP = 0
        self.FP = 0
        self.FN = 0
        self.num_labels = num_labels
        # P / AP: predicted positives and actual positives (MultiLabel mode only).
        self.P = 0
        self.AP = 0
        self.mode = mode
        if self.mode.lower() not in ("binary", "multilabel"):
            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")

    def update(self, logits, labels):
        """
        Fold one batch into the running counters.

        Args:
            logits: object exposing `asnumpy()`; the argmax over its last axis
                is taken as the predicted label id.
            labels: object exposing `asnumpy()` that yields the reference label ids.
        """
        labels = np.reshape(labels.asnumpy(), -1)
        logit_id = np.reshape(np.argmax(logits.asnumpy(), axis=-1), -1)

        if self.mode.lower() == "binary":
            positive_ids = np.arange(1, self.num_labels)
            pos_eva = np.isin(logit_id, positive_ids)
            pos_label = np.isin(labels, positive_ids)
            self.TP += np.sum(pos_eva & pos_label)
            self.FP += np.sum(pos_eva & (~pos_label))
            self.FN += np.sum((~pos_eva) & pos_label)
        else:
            # One-hot encode references and predictions, then count per label id.
            # `np.int` no longer exists in current NumPy, so use a concrete dtype.
            target = np.zeros((len(labels), self.num_labels), dtype=np.int64)
            pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int64)
            target[np.arange(len(labels)), labels] = 1
            pred[np.arange(len(logit_id)), logit_id] = 1
            self.TP += (target * pred).sum(axis=0)
            self.P += pred.sum(axis=0)
            self.AP += target.sum(axis=0)

    def eval(self):
        """
        Return the F1 score for everything accumulated so far.

        Binary mode uses 2*TP / (2*TP + FP + FN). MultiLabel mode sums the
        per-label counters first and uses 2*TP / (AP + P). Returns 0.0 when the
        denominator is zero (nothing accumulated, or no positives at all).
        """
        if self.mode.lower() == "binary":
            numerator = 2 * self.TP
            denominator = 2 * self.TP + self.FP + self.FN
        else:
            numerator = 2 * np.sum(self.TP)
            denominator = np.sum(self.AP) + np.sum(self.P)
        if not denominator:
            return 0.0
        return numerator / denominator

@ -28,7 +28,7 @@ from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .tinybert_model import BertModel, TinyBertModel, BertModelCLS from .tinybert_model import BertModel, TinyBertModel, BertModelCLS, BertModelNER
GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_TYPE = 1
@ -362,8 +362,18 @@ class BertNetworkWithLoss_td(nn.Cell):
temperature=1.0, dropout_prob=0.1): temperature=1.0, dropout_prob=0.1):
super(BertNetworkWithLoss_td, self).__init__() super(BertNetworkWithLoss_td, self).__init__()
# load teacher model # load teacher model
self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob, if task_type == "classification":
use_one_hot_embeddings, "teacher") self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
use_one_hot_embeddings, "teacher")
self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
use_one_hot_embeddings, "student")
elif task_type == "ner":
self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
use_one_hot_embeddings, "teacher")
self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
use_one_hot_embeddings, "student")
else:
raise ValueError(f"Not support task type: {task_type}")
param_dict = load_checkpoint(teacher_ckpt) param_dict = load_checkpoint(teacher_ckpt)
new_param_dict = {} new_param_dict = {}
for key, value in param_dict.items(): for key, value in param_dict.items():
@ -377,8 +387,6 @@ class BertNetworkWithLoss_td(nn.Cell):
for param in params: for param in params:
param.requires_grad = False param.requires_grad = False
# load student model # load student model
self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
use_one_hot_embeddings, "student")
param_dict = load_checkpoint(student_ckpt) param_dict = load_checkpoint(student_ckpt)
if is_predistill: if is_predistill:
new_param_dict = {} new_param_dict = {}
@ -401,7 +409,7 @@ class BertNetworkWithLoss_td(nn.Cell):
self.is_predistill = is_predistill self.is_predistill = is_predistill
self.is_att_fit = is_att_fit self.is_att_fit = is_att_fit
self.is_rep_fit = is_rep_fit self.is_rep_fit = is_rep_fit
self.task_type = task_type self.use_soft_cross_entropy = task_type in ["classification", "ner"]
self.temperature = temperature self.temperature = temperature
self.loss_mse = nn.MSELoss() self.loss_mse = nn.MSELoss()
self.select = P.Select() self.select = P.Select()
@ -456,7 +464,7 @@ class BertNetworkWithLoss_td(nn.Cell):
rep_loss += self.loss_mse(student_rep, teacher_rep) rep_loss += self.loss_mse(student_rep, teacher_rep)
total_loss += rep_loss total_loss += rep_loss
else: else:
if self.task_type == "classification": if self.use_soft_cross_entropy:
cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature) cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
else: else:
cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1]) cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])

@ -926,3 +926,40 @@ class BertModelCLS(nn.Cell):
if self._phase == 'train' or self.phase_type == "teacher": if self._phase == 'train' or self.phase_type == "teacher":
return seq_output, att_output, logits, log_probs return seq_output, att_output, logits, log_probs
return log_probs return log_probs
class BertModelNER(nn.Cell):
    """
    Sequence labeling (e.g. NER) head on top of BertModel.

    The sequence output of BertModel is passed through ReLU, flattened to
    (-1, hidden_size), projected to `num_labels` by a dense layer and
    normalised with log_softmax over the last axis.

    Args:
        config: network config; this class reads `hidden_size`, `seq_length`,
            `initializer_range`, `dtype` and `compute_type` from it.
        is_training (bool): forwarded to BertModel. When False, the dropout
            probabilities on `config` are set to 0.0 in place.
        num_labels (int): size of the output projection. Default: 11.
        dropout_prob (float): accepted but not used by this class.
        use_one_hot_embeddings (bool): forwarded to BertModel.
        phase_type (str): when "teacher", `construct` always returns the full
            tuple of intermediate outputs. Default: "student".
    """
    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
                 use_one_hot_embeddings=False, phase_type="student"):
        super(BertModelNER, self).__init__()
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.hidden_probs_dropout_prob = 0.0
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        # construct() branches on phase_type, so it has to be kept on the instance.
        self.phase_type = phase_type
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        # NOTE: despite the attribute name this is a ReLU, not a dropout layer.
        self.dropout = nn.ReLU()
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)
        self.origin_shape = (-1, config.seq_length, self.num_labels)

    def construct(self, input_ids, input_mask, token_type_id):
        """
        Return log_softmax over the per-token logits.

        In the 'train' phase, or when phase_type is "teacher", returns the tuple
        (encoder_outputs, attention_outputs, logits, log_probs) instead.
        """
        # NOTE(review): the evaluation loop in run_task_distill.py calls the model
        # positionally as (input_ids, token_type_id, input_mask), which binds to the
        # differently named parameters here - confirm against BertModel.construct.
        sequence_output, _, _, encoder_outputs, attention_outputs = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        logits = self.cast(logits, self.dtype)
        return_value = self.log_softmax(logits)
        if self._phase == 'train' or self.phase_type == "teacher":
            return encoder_outputs, attention_outputs, logits, return_value
        return return_value

Loading…
Cancel
Save