@@ -64,7 +64,6 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
-    context.set_context(variable_memory_max_size="30GB")
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
@@ -99,47 +98,49 @@ def run_pretrain():
         logger.warning('Gpu only support fp32 temporarily, run with fp32.')
         bert_net_cfg.compute_type = mstype.float32
 
-    ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
-                             args_opt.enable_data_sink, args_opt.data_sink_steps,
-                             args_opt.data_dir, args_opt.schema_dir)
-    new_repeat_count = args_opt.epoch_size
+    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
+    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
-    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+    else:
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
 
     if cfg.optimizer == 'Lamb':
         lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                        end_learning_rate=cfg.Lamb.end_learning_rate,
                                        warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.Lamb.power)
-        params = netwithloss.trainable_params()
+        params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params}]
+                        {'params': other_params},
+                        {'order_params': params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
+        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
     elif cfg.optimizer == 'AdamWeightDecay':
         lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                        end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                        warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=ds.get_dataset_size() * new_repeat_count,
+                                       decay_steps=args_opt.train_steps,
                                        power=cfg.AdamWeightDecay.power)
-        params = netwithloss.trainable_params()
+        params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0}]
+                        {'params': other_params, 'weight_decay': 0.0},
+                        {'order_params': params}]
 
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
 
-    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
+    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
     if args_opt.enable_save_ckpt == "true":
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
@@ -148,19 +149,22 @@ def run_pretrain():
 
     if args_opt.load_checkpoint_path:
         param_dict = load_checkpoint(args_opt.load_checkpoint_path)
-        load_param_into_net(netwithloss, param_dict)
+        load_param_into_net(net_with_loss, param_dict)
 
     if args_opt.enable_lossscale == "true":
         update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                  scale_factor=cfg.scale_factor,
                                                  scale_window=cfg.scale_window)
-        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                         scale_update_cell=update_cell)
+        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
+                                                           scale_update_cell=update_cell)
     else:
-        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
-    model = Model(netwithgrads)
-    model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
+        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
+
+    model = Model(net_with_grads)
+    model.train(new_repeat_count, ds, callbacks=callback,
+                dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
 
+
 if __name__ == '__main__':
     numpy.random.seed(0)
     run_pretrain()
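
Note on the sink-size bookkeeping above: with dataset sink mode, every "epoch" that model.train reports now consumes sink_size (= args_opt.data_sink_steps) batches instead of one full dataset pass, so new_repeat_count is rescaled from dataset passes to sink cycles. A minimal sketch of the arithmetic, using hypothetical sizes (dataset_size stands in for ds.get_dataset_size()):

    # Hypothetical numbers for illustration only.
    epoch_size = 40          # --epoch_size
    dataset_size = 10000     # batches in one full pass; ds.get_dataset_size()
    data_sink_steps = 100    # --data_sink_steps, passed to model.train as sink_size
    train_steps = 0          # --train_steps; 0 means derive it from epoch_size

    # One reported "epoch" consumes data_sink_steps batches, so the total
    # step budget is split into sink cycles.
    new_repeat_count = epoch_size * dataset_size // data_sink_steps
    if train_steps > 0:
        new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)
    else:
        train_steps = epoch_size * dataset_size

    assert new_repeat_count == 4000  # 400000 total steps / 100 steps per cycle
    assert new_repeat_count * data_sink_steps == train_steps

This is also why decay_steps switches to args_opt.train_steps (ds.get_dataset_size() * new_repeat_count no longer equals the total step count after the rescaling), and why TimeMonitor is handed args_opt.data_sink_steps, the new number of steps per reported epoch.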
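On the {'order_params': params} entries added to both group lists: as I read MindSpore's grouped-parameter interface (worth confirming against the nn.Optimizer documentation for the version in use), this entry contributes no parameters of its own and only pins the optimizer's flattened parameter order to match trainable_params(). A toy sketch of the resulting structure, with plain strings standing in for Parameter objects and a made-up decay filter:

    # Toy stand-ins; in the script these come from net_with_loss.trainable_params(),
    # cfg.Lamb.decay_filter, and cfg.Lamb.weight_decay.
    params = ['embedding.weight', 'dense.weight', 'layernorm.gamma']
    decay_params = [p for p in params if 'layernorm' not in p]      # hypothetical filter
    other_params = [p for p in params if 'layernorm' in p]
    group_params = [{'params': decay_params, 'weight_decay': 0.01},  # decayed group
                    {'params': other_params},                        # optimizer default
                    {'order_params': params}]                        # assumed: ordering only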