|
|
|
@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
|
|
|
|
|
from mindspore.nn.optim import AdamWeightDecay
|
|
|
|
|
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from src.dataset import create_tinybert_dataset
|
|
|
|
|
from src.dataset import create_tinybert_dataset, DataType
|
|
|
|
|
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
|
|
|
|
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
|
|
|
|
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
|
|
|
|
@ -55,6 +55,7 @@ def run_general_distill():
|
|
|
|
|
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
|
|
|
|
|
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
|
|
|
|
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
@ -99,8 +100,15 @@ def run_general_distill():
|
|
|
|
|
student_config=bert_student_net_cfg,
|
|
|
|
|
is_training=True, use_one_hot_embeddings=False)
|
|
|
|
|
|
|
|
|
|
if args_opt.dataset_type == "tfrecord":
|
|
|
|
|
dataset_type = DataType.TFRECORD
|
|
|
|
|
elif arg_opt.dataset_type == "mindrecord":
|
|
|
|
|
dataset_type = DataType.MINDRECORD
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("dataset format is not supported yet")
|
|
|
|
|
dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
|
|
|
|
|
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
|
|
|
|
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
|
|
|
|
|
data_type=dataset_type)
|
|
|
|
|
dataset_size = dataset.get_dataset_size()
|
|
|
|
|
print('dataset size: ', dataset_size)
|
|
|
|
|
print("dataset repeatcount: ", dataset.get_repeat_count())
|
|
|
|
|