@@ -28,7 +28,8 @@ from mindspore.context import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
+from mindspore.train.train_thor import ConvertModelUtils
+from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR
 from mindspore import log as logger
 from mindspore.common import set_seed
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
@@ -90,8 +91,27 @@ def _get_optimizer(args_opt, network):
             optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
         else:
             optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
+    elif cfg.optimizer == "Thor":
+        from src.utils import get_bert_thor_lr, get_bert_thor_damping
+        lr = get_bert_thor_lr()
+        damping = get_bert_thor_damping()
+        split_indices = None
+        if bert_net_cfg.num_hidden_layers == 12:
+            if bert_net_cfg.use_relative_positions:
+                split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
+            else:
+                split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
+        elif bert_net_cfg.num_hidden_layers == 24:
+            if bert_net_cfg.use_relative_positions:
+                split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
+            else:
+                split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
+        optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
+                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
+                         decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
+                         split_indices=split_indices)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
+        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                         format(cfg.optimizer))
     return optimizer
 
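For reference, the decay_filter passed to THOR above drops weight decay for any parameter whose lower-cased name contains 'layernorm' or 'bias'. A minimal, self-contained check of that lambda (the parameter names below are illustrative only, not taken from the actual BERT network):

decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()

class FakeParam:  # stand-in for a mindspore Parameter; only .name is needed here
    def __init__(self, name):
        self.name = name

for name in ('bert.encoder.layer.0.attention.self.query.weight',       # decayed
             'bert.encoder.layer.0.attention.output.layernorm.gamma',  # skipped
             'bert.encoder.layer.0.output.dense.bias'):                # skipped
    print(name, '->', 'decay' if decay_filter(FakeParam(name)) else 'no decay')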
@@ -244,6 +264,8 @@ def run_pretrain():
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
 
     model = Model(net_with_grads)
+    model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
+                                                      frequency=cfg.Thor.frequency)
     model.train(new_repeat_count, ds, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
 
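The new code paths read a Thor sub-config alongside the existing optimizer settings. A sketch of the entries they touch, with placeholder values: the key names follow the attribute accesses in the patch (cfg.Thor.momentum, cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size, cfg.Thor.frequency), while the easydict layout and the numbers are assumptions, not the values shipped with the model.

from easydict import EasyDict as edict

cfg = edict({
    'optimizer': 'Thor',        # selects the new elif branch in _get_optimizer()
    'batch_size': 32,           # placeholder
    'Thor': edict({
        'momentum': 0.9,        # placeholder
        'weight_decay': 5e-4,   # placeholder
        'loss_scale': 1.0,      # placeholder
        'frequency': 100,       # placeholder; passed to convert_to_thor_model()
    }),
})

Note that convert_to_thor_model() is applied unconditionally after Model(net_with_grads); as far as I understand the MindSpore API, it only rewraps the model when the optimizer is a THOR instance and otherwise returns it unchanged, so the Lamb/Momentum/AdamWeightDecay paths are unaffected.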