|
|
|
@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from src.dataset import create_tinybert_dataset, DataType
|
|
|
|
|
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
|
|
|
|
|
from src.assessment_method import Accuracy
|
|
|
|
|
from src.assessment_method import Accuracy, F1
|
|
|
|
|
from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
|
|
|
|
|
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
|
|
|
|
|
from src.tinybert_model import BertModelCLS
|
|
|
|
|
from src.tinybert_model import BertModelCLS, BertModelNER
|
|
|
|
|
|
|
|
|
|
_cur_dir = os.getcwd()
|
|
|
|
|
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
|
|
|
|
@ -46,7 +46,7 @@ def parse_args():
|
|
|
|
|
parse args
|
|
|
|
|
"""
|
|
|
|
|
parser = argparse.ArgumentParser(description='tinybert task distill')
|
|
|
|
|
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
|
|
|
|
parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
|
|
|
|
help='device where the code will be implemented. (Default: Ascend)')
|
|
|
|
|
parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
|
|
|
|
|
help="Do train task, default is true.")
|
|
|
|
@ -69,21 +69,46 @@ def parse_args():
|
|
|
|
|
parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
|
|
|
|
|
parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
|
|
|
|
|
help="The type of the task to train.")
|
|
|
|
|
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
|
|
|
|
|
help="The name of the task to train.")
|
|
|
|
|
parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
|
|
|
|
|
help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
|
|
|
|
|
parser.add_argument("--dataset_type", type=str, default="tfrecord",
|
|
|
|
|
help="dataset type tfrecord/mindrecord, default is tfrecord")
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
|
|
|
|
|
raise ValueError("do train or do eval must have one be true, please confirm your config")
|
|
|
|
|
if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
|
|
|
|
|
raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
|
|
|
|
|
if args.task_name in ["CLUENER"] and args.task_type != "ner":
|
|
|
|
|
raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
|
|
|
|
|
if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
|
|
|
|
|
(td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
|
|
|
|
|
logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "\
|
|
|
|
|
"EN vocabs according to the origin paper.")
|
|
|
|
|
if args.task_name in ["TNEWS", "CLUENER"] and \
|
|
|
|
|
(td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
|
|
|
|
|
logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for " \
|
|
|
|
|
"EN vocabs according to the origin paper.")
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
args_opt = parse_args()
|
|
|
|
|
|
|
|
|
|
if args_opt.dataset_type == "tfrecord":
|
|
|
|
|
dataset_type = DataType.TFRECORD
|
|
|
|
|
elif args_opt.dataset_type == "mindrecord":
|
|
|
|
|
dataset_type = DataType.MINDRECORD
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("dataset format is not supported yet")
|
|
|
|
|
DEFAULT_NUM_LABELS = 2
|
|
|
|
|
DEFAULT_SEQ_LENGTH = 128
|
|
|
|
|
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
|
|
|
|
|
"QNLI": {"num_labels": 2, "seq_length": 128},
|
|
|
|
|
"MNLI": {"num_labels": 3, "seq_length": 128}}
|
|
|
|
|
"MNLI": {"num_labels": 3, "seq_length": 128},
|
|
|
|
|
"TNEWS": {"num_labels": 15, "seq_length": 128},
|
|
|
|
|
"CLUENER": {"num_labels": 43, "seq_length": 128}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Task:
|
|
|
|
@ -112,29 +137,15 @@ def run_predistill():
|
|
|
|
|
run predistill
|
|
|
|
|
"""
|
|
|
|
|
cfg = phase1_cfg
|
|
|
|
|
if args_opt.device_target == "Ascend":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
|
elif args_opt.device_target == "GPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Target error, GPU or Ascend is supported.")
|
|
|
|
|
context.set_context(reserve_class_name_in_scope=False)
|
|
|
|
|
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
|
|
|
|
load_student_checkpoint_path = args_opt.load_gd_ckpt_path
|
|
|
|
|
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
|
|
|
|
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
|
|
|
|
|
is_training=True, task_type='classification',
|
|
|
|
|
is_training=True, task_type=args_opt.task_type,
|
|
|
|
|
num_labels=task.num_labels, is_predistill=True)
|
|
|
|
|
|
|
|
|
|
rank = 0
|
|
|
|
|
device_num = 1
|
|
|
|
|
|
|
|
|
|
if args_opt.dataset_type == "tfrecord":
|
|
|
|
|
dataset_type = DataType.TFRECORD
|
|
|
|
|
elif args_opt.dataset_type == "mindrecord":
|
|
|
|
|
dataset_type = DataType.MINDRECORD
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("dataset format is not supported yet")
|
|
|
|
|
dataset = create_tinybert_dataset('td', cfg.batch_size,
|
|
|
|
|
device_num, rank, args_opt.do_shuffle,
|
|
|
|
|
args_opt.train_data_dir, args_opt.schema_dir,
|
|
|
|
@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
|
|
|
|
|
raise ValueError("Student ckpt file should not be None")
|
|
|
|
|
cfg = phase2_cfg
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == "Ascend":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
|
elif args_opt.device_target == "GPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Target error, GPU or Ascend is supported.")
|
|
|
|
|
|
|
|
|
|
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
|
|
|
|
|
load_student_checkpoint_path = ckpt_file
|
|
|
|
|
netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
|
|
|
|
|
student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
|
|
|
|
|
is_training=True, task_type='classification',
|
|
|
|
|
is_training=True, task_type=args_opt.task_type,
|
|
|
|
|
num_labels=task.num_labels, is_predistill=False)
|
|
|
|
|
|
|
|
|
|
rank = 0
|
|
|
|
|
device_num = 1
|
|
|
|
|
train_dataset = create_tinybert_dataset('td', cfg.batch_size,
|
|
|
|
|
device_num, rank, args_opt.do_shuffle,
|
|
|
|
|
args_opt.train_data_dir, args_opt.schema_dir)
|
|
|
|
|
args_opt.train_data_dir, args_opt.schema_dir,
|
|
|
|
|
data_type=dataset_type)
|
|
|
|
|
|
|
|
|
|
dataset_size = train_dataset.get_dataset_size()
|
|
|
|
|
print('td2 train dataset size: ', dataset_size)
|
|
|
|
@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):
|
|
|
|
|
|
|
|
|
|
eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
|
|
|
|
|
device_num, rank, args_opt.do_shuffle,
|
|
|
|
|
args_opt.eval_data_dir, args_opt.schema_dir)
|
|
|
|
|
args_opt.eval_data_dir, args_opt.schema_dir,
|
|
|
|
|
data_type=dataset_type)
|
|
|
|
|
print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
|
|
|
|
|
|
|
|
|
|
if args_opt.do_eval.lower() == "true":
|
|
|
|
@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
|
|
|
|
|
dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
|
|
|
|
|
sink_size=args_opt.data_sink_steps)
|
|
|
|
|
|
|
|
|
|
def eval_result_print(assessment_method="accuracy", callback=None):
|
|
|
|
|
"""print eval result"""
|
|
|
|
|
if assessment_method == "accuracy":
|
|
|
|
|
print("============== acc is {}".format(callback.acc_num / callback.total_num))
|
|
|
|
|
elif assessment_method == "bf1":
|
|
|
|
|
print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
|
|
|
|
|
print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
|
|
|
|
|
print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
|
|
|
|
|
elif assessment_method == "mf1":
|
|
|
|
|
print("F1 {:.6f} ".format(callback.eval()))
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Assessment method not supported, support: [accuracy, f1]")
|
|
|
|
|
|
|
|
|
|
def do_eval_standalone():
|
|
|
|
|
"""
|
|
|
|
|
do eval standalone
|
|
|
|
@ -270,13 +289,12 @@ def do_eval_standalone():
|
|
|
|
|
ckpt_file = args_opt.load_td1_ckpt_path
|
|
|
|
|
if ckpt_file == '':
|
|
|
|
|
raise ValueError("Student ckpt file should not be None")
|
|
|
|
|
if args_opt.device_target == "Ascend":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
|
elif args_opt.device_target == "GPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
|
|
|
|
if args_opt.task_type == "classification":
|
|
|
|
|
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
|
|
|
|
|
elif args_opt.task_type == "ner":
|
|
|
|
|
eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Target error, GPU or Ascend is supported.")
|
|
|
|
|
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
|
|
|
|
|
raise ValueError(f"Not support the task type {args_opt.task_type}")
|
|
|
|
|
param_dict = load_checkpoint(ckpt_file)
|
|
|
|
|
new_param_dict = {}
|
|
|
|
|
for key, value in param_dict.items():
|
|
|
|
@ -289,11 +307,18 @@ def do_eval_standalone():
|
|
|
|
|
eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
|
|
|
|
|
device_num=1, rank=0, do_shuffle="false",
|
|
|
|
|
data_dir=args_opt.eval_data_dir,
|
|
|
|
|
schema_dir=args_opt.schema_dir)
|
|
|
|
|
schema_dir=args_opt.schema_dir,
|
|
|
|
|
data_type=dataset_type)
|
|
|
|
|
print('eval dataset size: ', eval_dataset.get_dataset_size())
|
|
|
|
|
print('eval dataset batch size: ', eval_dataset.get_batch_size())
|
|
|
|
|
|
|
|
|
|
callback = Accuracy()
|
|
|
|
|
if args_opt.assessment_method == "accuracy":
|
|
|
|
|
callback = Accuracy()
|
|
|
|
|
elif args_opt.assessment_method == "bf1":
|
|
|
|
|
callback = F1(num_labels=task.num_labels)
|
|
|
|
|
elif args_opt.assessment_method == "mf1":
|
|
|
|
|
callback = F1(num_labels=task.num_labels, mode="MultiLabel")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Assessment method not supported, support: [accuracy, f1]")
|
|
|
|
|
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
|
|
|
|
for data in eval_dataset.create_dict_iterator(num_epochs=1):
|
|
|
|
|
input_data = []
|
|
|
|
@ -302,16 +327,16 @@ def do_eval_standalone():
|
|
|
|
|
input_ids, input_mask, token_type_id, label_ids = input_data
|
|
|
|
|
logits = eval_model(input_ids, token_type_id, input_mask)
|
|
|
|
|
callback.update(logits, label_ids)
|
|
|
|
|
acc = callback.acc_num / callback.total_num
|
|
|
|
|
print("======================================")
|
|
|
|
|
print("============== acc is {}".format(acc))
|
|
|
|
|
print("======================================")
|
|
|
|
|
print("==============================================================")
|
|
|
|
|
eval_result_print(args_opt.assessment_method, callback)
|
|
|
|
|
print("==============================================================")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
|
|
|
|
|
raise ValueError("do_train or do eval must have one be true, please confirm your config")
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
|
|
|
|
reserve_class_name_in_scope=False)
|
|
|
|
|
if args_opt.device_target == "Ascend":
|
|
|
|
|
context.set_context(device_id=args_opt.device_id)
|
|
|
|
|
enable_loss_scale = True
|
|
|
|
|
if args_opt.device_target == "GPU":
|
|
|
|
|
if td_student_net_cfg.compute_type != mstype.float32:
|
|
|
|
@ -321,6 +346,14 @@ if __name__ == '__main__':
|
|
|
|
|
# and the loss scale is not necessary
|
|
|
|
|
enable_loss_scale = False
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == "CPU":
|
|
|
|
|
logger.warning('CPU only support float32 temporarily, run with float32.')
|
|
|
|
|
td_teacher_net_cfg.dtype = mstype.float32
|
|
|
|
|
td_teacher_net_cfg.compute_type = mstype.float32
|
|
|
|
|
td_student_net_cfg.dtype = mstype.float32
|
|
|
|
|
td_student_net_cfg.compute_type = mstype.float32
|
|
|
|
|
enable_loss_scale = False
|
|
|
|
|
|
|
|
|
|
td_teacher_net_cfg.seq_length = task.seq_length
|
|
|
|
|
td_student_net_cfg.seq_length = task.seq_length
|
|
|
|
|
|
|
|
|
|