Add new optimizer THOR option to BERT pretrain script.

pull/12022/head
MingHan-Y 4 years ago
parent 856d6f58cf
commit 67a4c62b4b

@ -28,7 +28,8 @@ from mindspore.context import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore.train.train_thor import ConvertModelUtils
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR
from mindspore import log as logger
from mindspore.common import set_seed
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
@ -90,8 +91,27 @@ def _get_optimizer(args_opt, network):
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else:
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == "Thor":
from src.utils import get_bert_thor_lr, get_bert_thor_damping
lr = get_bert_thor_lr()
damping = get_bert_thor_damping()
split_indices = None
if bert_net_cfg.num_hidden_layers == 12:
if bert_net_cfg.use_relative_positions:
split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
else:
split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
elif bert_net_cfg.num_hidden_layers == 24:
if bert_net_cfg.use_relative_positions:
split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
else:
split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
split_indices=split_indices)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
format(cfg.optimizer))
return optimizer
@ -244,6 +264,8 @@ def run_pretrain():
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
model = Model(net_with_grads)
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
frequency=cfg.Thor.frequency)
model.train(new_repeat_count, ds, callbacks=callback,
dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)

Loading…
Cancel
Save