@@ -18,6 +18,7 @@
 import os
 import re
 import argparse
+import mindspore.common.dtype as mstype
 from mindspore import Tensor
 from mindspore import context
 from mindspore.train.model import Model
@@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
-from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
+from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
 from src.tinybert_model import BertModelCLS
 
 _cur_dir = os.getcwd()
@@ -45,14 +47,14 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
     parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
     parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
                         help="Epoch size for td phase 1, default is 10.")
     parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
-    parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.")
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
     parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
     parser.add_argument("--save_ckpt_step", type=int, default=100, help="The step to save checkpoint, default is 100.")
@@ -64,11 +66,43 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+                        help="The name of the task to train.")
     args = parser.parse_args()
     return args
 
 
 args_opt = parse_args()
+
+DEFAULT_NUM_LABELS = 2
+DEFAULT_SEQ_LENGTH = 128
+task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
+               "QNLI": {"num_labels": 2, "seq_length": 128},
+               "MNLI": {"num_labels": 3, "seq_length": 128}}
+
+
+class Task:
+    """
+    Encapsulation class for getting task-specific parameters.
+    """
+    def __init__(self, task_name):
+        self.task_name = task_name
+
+    @property
+    def num_labels(self):
+        if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
+            return task_params[self.task_name]["num_labels"]
+        return DEFAULT_NUM_LABELS
+
+    @property
+    def seq_length(self):
+        if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
+            return task_params[self.task_name]["seq_length"]
+        return DEFAULT_SEQ_LENGTH
+
+
+task = Task(args_opt.task_name)
+
+
 def run_predistill():
     """
     run predistill
@@ -81,7 +115,7 @@ def run_predistill():
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=True)
+                                         num_labels=task.num_labels, is_predistill=True)
 
     rank = 0
     device_num = 1
@@ -91,8 +125,9 @@
 
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
+    print('td1 dataset repeatcount: ', dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
-        repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase1_epoch_size
@@ -117,10 +152,14 @@ def run_predistill():
                                   args_opt.save_ckpt_step,
                                   args_opt.max_ckpt_num,
                                   td_phase1_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
+
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
@@ -139,7 +178,7 @@ def run_task_distill(ckpt_file):
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=False)
+                                         num_labels=task.num_labels, is_predistill=False)
 
     rank = 0
     device_num = 1
@@ -149,6 +188,7 @@ def run_task_distill(ckpt_file):
 
     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
+    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
@@ -175,6 +215,7 @@ def run_task_distill(ckpt_file):
     eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.eval_data_dir, args_opt.schema_dir)
+    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
 
     if args_opt.do_eval.lower() == "true":
         callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
@@ -185,11 +226,14 @@ def run_task_distill(ckpt_file):
                                   args_opt.save_ckpt_step,
                                   args_opt.max_ckpt_num,
                                   td_phase2_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
 
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, train_dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
@@ -203,7 +247,7 @@ def do_eval_standalone():
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
+    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -213,10 +257,13 @@ def do_eval_standalone():
     load_param_into_net(eval_model, new_param_dict)
     eval_model.set_train(False)
 
-    eval_dataset = create_tinybert_dataset('td', batch_size=1,
+    eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
                                            schema_dir=args_opt.schema_dir)
+    print('eval dataset size: ', eval_dataset.get_dataset_size())
+    print('eval dataset batch size: ', eval_dataset.get_batch_size())
+
     callback = Accuracy()
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator():
@@ -231,9 +278,26 @@
     print("============== acc is {}".format(acc))
     print("======================================")
 
 
 if __name__ == '__main__':
     if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
         raise ValueError("do_train or do_eval must have one be true, please confirm your config")
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if td_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_teacher_net_cfg.compute_type = mstype.float32
+        if td_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward passes of the network run in fp32,
+        # so the loss scale is not necessary.
+        enable_loss_scale = False
+
+    td_teacher_net_cfg.seq_length = task.seq_length
+    td_student_net_cfg.seq_length = task.seq_length
+
     if args_opt.do_train == "true":
         # run predistill
         run_predistill()