# Repository path: mindspore/model_zoo/official/nlp/tinybert/run_task_distill.py
# (Web-viewer residue captured with this file -- a topic-selection hint and the
# "326 lines / 17 KiB" size banner -- is not part of the program.)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""task distill script"""
import os
import re
import argparse
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore import log as logger
from src.dataset import create_tinybert_dataset, DataType
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.assessment_method import Accuracy
from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
from src.tinybert_model import BertModelCLS
# Checkpoint output directories for the two task-distill phases, located under
# the directory the script is launched from.
_cur_dir = os.getcwd()
td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
td_phase2_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase2_save_ckpt')
# exist_ok=True replaces the check-then-create pattern, which can fail when
# another process creates the directory between the check and the call.
os.makedirs(td_phase1_save_ckpt_dir, exist_ok=True)
os.makedirs(td_phase2_save_ckpt_dir, exist_ok=True)
def parse_args():
    """
    Parse the command-line options of the task-distill script.

    Boolean-like switches (--do_train, --do_eval, --do_shuffle,
    --enable_data_sink) are kept as the strings "true"/"false"; callers compare
    them against "true".

    Returns:
        argparse.Namespace, the parsed options read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='tinybert task distill')
    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
                        help="Do train task, default is true.")
    parser.add_argument("--do_eval", type=str, default="true", choices=["true", "false"],
                        help="Do eval task, default is true.")
    parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
                        help="Epoch size for td phase 1, default is 10.")
    parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                        help="Enable data sink, default is true.")
    # The two help strings below used to be copies of the --enable_data_sink text.
    parser.add_argument("--save_ckpt_step", type=int, default=100,
                        help="Steps between checkpoint saves, default is 100.")
    parser.add_argument("--max_ckpt_num", type=int, default=1,
                        help="Max number of checkpoint files to keep, default is 1.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--load_gd_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--load_td1_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
                        help="The name of the task to train.")
    parser.add_argument("--dataset_type", type=str, default="tfrecord",
                        help="dataset type tfrecord/mindrecord, default is tfrecord")
    return parser.parse_args()
# Command-line options, parsed once when the module is loaded and read by every
# function below.
args_opt = parse_args()

# Values used by ``Task`` when a task name has no entry in ``task_params``.
DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128

# Label count and sequence length for each supported task name.
task_params = {
    "SST-2": {"num_labels": 2, "seq_length": 64},
    "QNLI": {"num_labels": 2, "seq_length": 128},
    "MNLI": {"num_labels": 3, "seq_length": 128},
}
class Task:
    """
    Read-only view of the per-task settings stored in ``task_params``.

    Any setting missing for ``task_name`` (including an unknown task name)
    resolves to the matching module-level default.
    """

    def __init__(self, task_name):
        self.task_name = task_name

    def _setting(self, field, fallback):
        """Return ``field`` for this task from ``task_params``, else ``fallback``."""
        return task_params.get(self.task_name, {}).get(field, fallback)

    @property
    def num_labels(self):
        """Label count for the task, or ``DEFAULT_NUM_LABELS`` when not listed."""
        return self._setting("num_labels", DEFAULT_NUM_LABELS)

    @property
    def seq_length(self):
        """Sequence length for the task, or ``DEFAULT_SEQ_LENGTH`` when not listed."""
        return self._setting("seq_length", DEFAULT_SEQ_LENGTH)


# Task descriptor built from the --task_name option.
task = Task(args_opt.task_name)
def run_predistill():
    """
    Run the first task-distill phase (predistill).

    Builds ``BertNetworkWithLoss_td`` with ``is_predistill=True`` from the
    --load_teacher_ckpt_path and --load_gd_ckpt_path checkpoints, trains it on
    the data under --train_data_dir with the ``phase1_cfg`` settings, and
    registers ``ModelSaveCkpt`` on ``netwithloss.bert`` with
    ``td_phase1_save_ckpt_dir`` as its output directory.

    Reads the module-level ``enable_loss_scale`` flag, which is assigned in the
    ``__main__`` section of this script.

    Raises:
        Exception: if --dataset_type is neither "tfrecord" nor "mindrecord".
    """
    stage_cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)

    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=args_opt.load_gd_ckpt_path,
                                         is_training=True, task_type='classification',
                                         num_labels=task.num_labels, is_predistill=True)

    # Map the --dataset_type string onto the loader's enum.
    known_formats = {"tfrecord": DataType.TFRECORD, "mindrecord": DataType.MINDRECORD}
    if args_opt.dataset_type not in known_formats:
        raise Exception("dataset format is not supported yet")
    dataset_type = known_formats[args_opt.dataset_type]

    # Single-device run: device_num=1, rank=0.
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      1, 0, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir,
                                      data_type=dataset_type)
    data_size = dataset.get_dataset_size()
    print('td1 dataset size: ', data_size)
    print('td1 dataset repeatcount: ', dataset.get_repeat_count())

    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # With sinking, model.train is driven in units of data_sink_steps.
        repeat_count = args_opt.td_phase1_epoch_size * data_size // args_opt.data_sink_steps
        monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        monitor_steps = data_size

    adam_cfg = stage_cfg.optimizer_cfg.AdamWeightDecay
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(data_size / 10),
                                   decay_steps=int(data_size * args_opt.td_phase1_epoch_size),
                                   power=adam_cfg.power)

    # Weight decay applies only to parameters accepted by decay_filter.
    params = netwithloss.trainable_params()
    decay_params = [param for param in params if adam_cfg.decay_filter(param)]
    other_params = [param for param in params if not adam_cfg.decay_filter(param)]
    group_params = [
        {'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
        {'params': other_params, 'weight_decay': 0.0},
        {'order_params': params},
    ]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    ckpt_saver = ModelSaveCkpt(netwithloss.bert,
                               args_opt.save_ckpt_step,
                               args_opt.max_ckpt_num,
                               td_phase1_save_ckpt_dir)
    callback = [TimeMonitor(monitor_steps), LossCallBack(), ckpt_saver]

    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=stage_cfg.loss_scale_value,
                                                 scale_factor=stage_cfg.scale_factor,
                                                 scale_window=stage_cfg.scale_window)
        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                       scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)

    Model(netwithgrads).train(repeat_count, dataset, callbacks=callback,
                              dataset_sink_mode=sink_mode,
                              sink_size=args_opt.data_sink_steps)
def run_task_distill(ckpt_file):
    """
    Run the second task-distill phase.

    Builds ``BertNetworkWithLoss_td`` with ``is_predistill=False`` from the
    --load_teacher_ckpt_path checkpoint and ``ckpt_file``, then trains it on
    the data under --train_data_dir with the ``phase2_cfg`` settings. When
    --do_eval is "true" an ``EvalCallBack`` over --eval_data_dir is attached;
    otherwise ``ModelSaveCkpt`` is attached with ``td_phase2_save_ckpt_dir``.

    Reads the module-level ``enable_loss_scale`` flag, which is assigned in the
    ``__main__`` section of this script.

    Args:
        ckpt_file (str): student checkpoint path passed to the loss network.

    Raises:
        ValueError: if ``ckpt_file`` is empty or None.
        Exception: if --dataset_type is neither "tfrecord" nor "mindrecord".
    """
    # Catch None as well as '' -- the error message already promises that.
    if not ckpt_file:
        raise ValueError("Student ckpt file should not be None")

    # Honour --dataset_type here too, so that this phase reads the data in the
    # same format as run_predistill instead of the loader's default.
    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")

    cfg = phase2_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)

    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=ckpt_file,
                                         is_training=True, task_type='classification',
                                         num_labels=task.num_labels, is_predistill=False)

    # Single-device run.
    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.train_data_dir, args_opt.schema_dir,
                                            data_type=dataset_type)
    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())

    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # With sinking, model.train is driven in units of data_sink_steps.
        repeat_count = args_opt.td_phase2_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    adam_cfg = cfg.optimizer_cfg.AdamWeightDecay
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
                                   power=adam_cfg.power)

    # Weight decay applies only to parameters accepted by decay_filter.
    params = netwithloss.trainable_params()
    decay_params = [param for param in params if adam_cfg.decay_filter(param)]
    other_params = [param for param in params if not adam_cfg.decay_filter(param)]
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    if args_opt.do_eval.lower() == "true":
        # The eval dataset is only built when it is used, so a train-only run
        # does not require --eval_data_dir.
        eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                               device_num, rank, args_opt.do_shuffle,
                                               args_opt.eval_data_dir, args_opt.schema_dir,
                                               data_type=dataset_type)
        print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    EvalCallBack(netwithloss.bert, eval_dataset)]
    else:
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    ModelSaveCkpt(netwithloss.bert,
                                  args_opt.save_ckpt_step,
                                  args_opt.max_ckpt_num,
                                  td_phase2_save_ckpt_dir)]

    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                       scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
def do_eval_standalone():
    """
    Evaluate a saved student checkpoint without training.

    Loads --load_td1_ckpt_path into a ``BertModelCLS`` student model, runs it
    over the data under --eval_data_dir and prints the accuracy computed by
    ``Accuracy`` (``acc_num / total_num``).

    Raises:
        ValueError: if --load_td1_ckpt_path is empty, or if the eval dataset
            yields no samples.
        Exception: if --dataset_type is neither "tfrecord" nor "mindrecord".
    """
    ckpt_file = args_opt.load_td1_ckpt_path
    # Catch None as well as '' -- the error message already promises that.
    if not ckpt_file:
        raise ValueError("Student ckpt file should not be None")

    # Honour --dataset_type so eval reads the data in the same format as
    # run_predistill instead of the loader's default.
    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")

    # Rename checkpoint keys to the names used by eval_model: 'tinybert_'
    # becomes 'bert_' and a leading 'bert.' scope is dropped. The dot is
    # escaped so that only the literal 'bert.' prefix is removed; the previous
    # pattern '^bert.' also stripped prefixes such as 'bert_'.
    param_dict = load_checkpoint(ckpt_file)
    new_param_dict = {}
    for key, value in param_dict.items():
        new_key = key.replace('tinybert_', 'bert_')
        new_key = re.sub(r'^bert\.', '', new_key)
        new_param_dict[new_key] = value
    load_param_into_net(eval_model, new_param_dict)
    eval_model.set_train(False)

    eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
                                           device_num=1, rank=0, do_shuffle="false",
                                           data_dir=args_opt.eval_data_dir,
                                           schema_dir=args_opt.schema_dir,
                                           data_type=dataset_type)
    print('eval dataset size: ', eval_dataset.get_dataset_size())
    print('eval dataset batch size: ', eval_dataset.get_batch_size())

    callback = Accuracy()
    for data in eval_dataset.create_dict_iterator():
        # The model takes (input_ids, segment_ids, input_mask), in that order.
        logits = eval_model(data["input_ids"], data["segment_ids"], data["input_mask"])
        # Element 3 of the model output is what the accuracy metric consumes.
        callback.update(logits[3], data["label_ids"])

    # Fail with a clear message instead of a bare ZeroDivisionError.
    if not callback.total_num:
        raise ValueError("Eval dataset is empty, accuracy can not be computed")
    acc = callback.acc_num / callback.total_num
    print("======================================")
    print("============== acc is {}".format(acc))
    print("======================================")
if __name__ == '__main__':
    # Script entry: --do_train runs both distill phases back to back (phase 2
    # starts from the newest phase-1 checkpoint); otherwise only the standalone
    # evaluation runs.
    do_train = args_opt.do_train.lower() == "true"
    do_eval = args_opt.do_eval.lower() == "true"
    if not do_train and not do_eval:
        raise ValueError("do_train or do eval must have one be true, please confirm your config")

    # Module-level flag read by run_predistill and run_task_distill.
    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if td_student_net_cfg.compute_type != mstype.float32:
            logger.warning('Compute about the student only support float32 temporarily, run with float32.')
            td_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    td_teacher_net_cfg.seq_length = task.seq_length
    td_student_net_cfg.seq_length = task.seq_length

    if do_train:
        # run predistill
        run_predistill()
        # Only consider .ckpt files, so that an unrelated newer file in the
        # directory can not abort phase 2 after phase 1 has already finished.
        ckpt_paths = [os.path.join(td_phase1_save_ckpt_dir, file_name)
                      for file_name in os.listdir(td_phase1_save_ckpt_dir)
                      if os.path.splitext(file_name)[-1] == ".ckpt"]
        if not ckpt_paths:
            raise ValueError("Checkpoint file not exists, please make sure ckpt file has been saved")
        newest_ckpt_file = max(ckpt_paths, key=os.path.getmtime)
        # run task distill
        run_task_distill(newest_ckpt_file)
    else:
        do_eval_standalone()