@@ -25,12 +25,11 @@ from src.config import cfg
 from src.dataset import create_bert_dataset
 from src.lr_generator import get_bert_lr, get_bert_damping
 from src.model_thor import Model
-from src.utils import LossCallBack, BertLearningRate
+from src.utils import LossCallBack
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore import log as logger
-from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.context import ParallelMode
@@ -68,38 +67,8 @@ def _set_bert_all_reduce_split():
 
 
 def _get_optimizer(args_opt, network):
-    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay and Thor."""
-    if cfg.optimizer == 'Lamb':
-        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
-                                       end_learning_rate=cfg.Lamb.end_learning_rate,
-                                       warmup_steps=cfg.Lamb.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.Lamb.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.Lamb.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
-        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
-    elif cfg.optimizer == 'Momentum':
-        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
-                             momentum=cfg.Momentum.momentum)
-    elif cfg.optimizer == 'AdamWeightDecay':
-        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
-                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
-                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
-                                       decay_steps=args_opt.train_steps,
-                                       power=cfg.AdamWeightDecay.power)
-        params = network.trainable_params()
-        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
-        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
-
-        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
-    elif cfg.optimizer == "Thor":
+    """get thor optimizer."""
+    if cfg.optimizer == "Thor":
         if args_opt.distribute == "true":
             from src.thor_for_bert_arg import THOR
         else:
@@ -112,8 +81,7 @@ def _get_optimizer(args_opt, network):
                          cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
                          bert_net_cfg.batch_size, damping)
     else:
-        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
-                         format(cfg.optimizer))
+        raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
     return optimizer
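
Note: the Lamb and AdamWeightDecay branches removed above both use MindSpore's grouped-parameters convention (a per-group 'weight_decay' plus an 'order_params' entry that preserves the original parameter order). For reference, below is a minimal, self-contained sketch of that pattern only; it is not code from this repository. build_group_params, no_decay_filter, and the nn.Dense toy network are illustrative stand-ins for cfg.<optimizer>.decay_filter and the BERT network, which are not shown in this diff.

import mindspore.nn as nn
from mindspore.nn.optim import Lamb


def no_decay_filter(param):
    """Assumed filter: apply weight decay to everything except bias/normalization params."""
    name = param.name.lower()
    return not any(key in name for key in ('bias', 'beta', 'gamma', 'layernorm'))


def build_group_params(network, decay_filter, weight_decay):
    """Split trainable params into decayed / non-decayed groups, keeping original order."""
    params = network.trainable_params()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    return [{'params': decay_params, 'weight_decay': weight_decay},
            {'params': other_params, 'weight_decay': 0.0},
            {'order_params': params}]


net = nn.Dense(4, 2)  # toy stand-in for the BERT network
optimizer = Lamb(build_group_params(net, no_decay_filter, 0.01), learning_rate=1e-4, eps=1e-6)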