From 67a4c62b4b0ace380417ac2f02e321cb9984fd65 Mon Sep 17 00:00:00 2001
From: MingHan-Y
Date: Wed, 3 Feb 2021 09:51:25 +0800
Subject: [PATCH] Add new optimizer THOR option to BERT pretrain script.

---
 model_zoo/official/nlp/bert/run_pretrain.py | 26 +++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 7703bcae76..78fd79a787 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -28,7 +28,8 @@ from mindspore.context import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
+from mindspore.train.train_thor import ConvertModelUtils
+from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR
 from mindspore import log as logger
 from mindspore.common import set_seed
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
@@ -90,8 +91,27 @@ def _get_optimizer(args_opt, network):
             optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
         else:
             optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
+    elif cfg.optimizer == "Thor":
+        from src.utils import get_bert_thor_lr, get_bert_thor_damping
+        lr = get_bert_thor_lr()
+        damping = get_bert_thor_damping()
+        split_indices = None
+        if bert_net_cfg.num_hidden_layers == 12:
+            if bert_net_cfg.use_relative_positions:
+                split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
+            else:
+                split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
+        elif bert_net_cfg.num_hidden_layers == 24:
+            if bert_net_cfg.use_relative_positions:
+                split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
+            else:
+                split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
+        optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
+                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
+                         decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
+                         split_indices=split_indices)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
+        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                          format(cfg.optimizer))
     return optimizer
 
@@ -244,6 +264,8 @@ def run_pretrain():
             net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
 
     model = Model(net_with_grads)
+    model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
+                                                      frequency=cfg.Thor.frequency)
     model.train(new_repeat_count, ds, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
 
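
The new Thor branch and the convert_to_thor_model call read configuration fields that this patch does not itself define: cfg.optimizer, cfg.batch_size, and a cfg.Thor section with momentum, weight_decay, loss_scale, and frequency. A minimal sketch of what the corresponding entries in src/config.py could look like is given below; the field names follow the patch, but the default values (and the use of easydict) are illustrative assumptions only, not part of this change.

    # Hypothetical sketch of the cfg entries assumed by the Thor branch above.
    # Values are placeholders, not taken from this patch.
    from easydict import EasyDict as edict

    cfg = edict({
        'batch_size': 32,            # passed to THOR as cfg.batch_size
        'optimizer': 'Thor',         # selects the new elif branch in _get_optimizer()
        'Thor': edict({
            'momentum': 0.9,         # cfg.Thor.momentum
            'weight_decay': 5e-4,    # cfg.Thor.weight_decay
            'loss_scale': 1.0,       # cfg.Thor.loss_scale
            'frequency': 100,        # cfg.Thor.frequency, used by convert_to_thor_model()
        }),
    })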