diff --git a/model_zoo/official/nlp/bert_thor/run_pretrain.py b/model_zoo/official/nlp/bert_thor/run_pretrain.py
index 410878e08a..b7a33b22a5 100644
--- a/model_zoo/official/nlp/bert_thor/run_pretrain.py
+++ b/model_zoo/official/nlp/bert_thor/run_pretrain.py
@@ -25,12 +25,11 @@ from src.config import cfg
 from src.dataset import create_bert_dataset
 from src.lr_generator import get_bert_lr, get_bert_damping
 from src.model_thor import Model
-from src.utils import LossCallBack, BertLearningRate
+from src.utils import LossCallBack
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore import log as logger
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.context import ParallelMode
@@ -68,38 +67,8 @@ def _set_bert_all_reduce_split():
 
 
 def _get_optimizer(args_opt, network):
-    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
-    if cfg.optimizer == 'Lamb':
-        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
-                                       end_learning_rate=cfg.Lamb.end_learning_rate,
-                                       warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.Lamb.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.Lamb.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
-        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
-    elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
-                             momentum=cfg.Momentum.momentum)
-    elif cfg.optimizer == 'AdamWeightDecay':
-        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
-                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
-                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.AdamWeightDecay.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
-
-        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
-    elif cfg.optimizer == "Thor":
+    """get thor optimizer."""
+    if cfg.optimizer == "Thor":
         if args_opt.distribute == "true":
             from src.thor_for_bert_arg import THOR
         else:
@@ -112,8 +81,7 @@ def _get_optimizer(args_opt, network):
                          cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
                          bert_net_cfg.batch_size, damping)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
-                         format(cfg.optimizer))
+        raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
     return optimizer
 
 
diff --git a/model_zoo/official/nlp/bert_thor/src/config.py b/model_zoo/official/nlp/bert_thor/src/config.py
index b17eecd0f6..6613831a86 100644
--- a/model_zoo/official/nlp/bert_thor/src/config.py
+++ b/model_zoo/official/nlp/bert_thor/src/config.py
@@ -20,28 +20,6 @@ from easydict import EasyDict as edict
 cfg = edict({
     'bert_network': 'large',
     'optimizer': 'Thor',
-    'AdamWeightDecay': edict({
-        'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
-        'power': 5.0,
-        'weight_decay': 1e-5,
-        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
-        'warmup_steps': 10000,
-    }),
-    'Lamb': edict({
-        'learning_rate': 3e-5,
-        'end_learning_rate': 1e-10,
-        'power': 10.0,
-        'warmup_steps': 10000,
-        'weight_decay': 0.01,
-        'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
-    }),
-    'Momentum': edict({
-        'learning_rate': 2e-5,
-        'momentum': 0.9,
-    }),
     'Thor': edict({
         'momentum': 0.9,
         'weight_decay': 5e-4,