!3335 script update for BERT to fix pre-training failures when the warmup step count is set to zero

Merge pull request !3335 from shibeiji/bert_script_debug
pull/3335/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit db1a1fb88b

@@ -117,8 +117,7 @@ def run_pretrain():
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
+                        {'params': other_params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
         optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
@@ -133,8 +132,7 @@ def run_pretrain():
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
+                        {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
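For context, the two hunks above drop the third {'order_params': params} group and keep only two groups: parameters that receive weight decay and parameters that do not. A minimal, self-contained sketch of that grouping logic (the Param tuple and the decay_filter below are illustrative stand-ins, not the MindSpore config objects):

from collections import namedtuple

# Toy stand-in for a trainable parameter; only the name matters here.
Param = namedtuple('Param', ['name'])

params = [Param('bert.encoder.layer0.dense.weight'),
          Param('bert.encoder.layer0.dense.bias'),
          Param('bert.encoder.layer0.layernorm.gamma')]

# Illustrative decay filter in the spirit of cfg.Lamb.decay_filter /
# cfg.AdamWeightDecay.decay_filter: exclude bias and normalization parameters.
def decay_filter(param):
    return not any(key in param.name.lower() for key in ('layernorm', 'bias', 'gamma', 'beta'))

decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))

# Two groups only; the former third entry {'order_params': params} is gone.
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0}]

print([p.name for group in group_params for p in group['params']])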

@@ -22,7 +22,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -55,7 +55,7 @@ class BertFinetuneCell(nn.Cell):
         super(BertFinetuneCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad',
                                     get_by_list=True,
@@ -158,7 +158,7 @@ class BertSquadCell(nn.Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertSquadCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.reducer_flag = False

@@ -21,7 +21,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -270,7 +270,7 @@ class BertTrainOneStepCell(nn.Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
@@ -349,7 +349,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad',
                                     get_by_list=True,
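The recurring change from `self.weights = ParameterTuple(network.trainable_params())` to `self.weights = optimizer.parameters` makes every train and finetune cell take its weight list from the optimizer, so the gradients produced by `C.GradOperation(get_by_list=True)` line up with the parameter list the optimizer actually updates. A plain-Python sketch of that ordering concern, with toy Optimizer and TrainCell classes that are illustrative only, not the MindSpore classes:

# Toy illustration: gradients are positional with respect to a weight list,
# so the train cell should iterate the optimizer's own parameter list.
class ToyOptimizer:
    def __init__(self, parameters, lr=0.1):
        self.parameters = list(parameters)   # the order the optimizer updates in
        self.lr = lr

    def apply(self, grads):
        for param, grad in zip(self.parameters, grads):
            param['value'] -= self.lr * grad

class ToyTrainCell:
    def __init__(self, optimizer):
        # Take the weight list from the optimizer rather than from the network,
        # mirroring self.weights = optimizer.parameters in the hunks above.
        self.weights = optimizer.parameters
        self.optimizer = optimizer

    def train_step(self, grad_fn):
        grads = [grad_fn(w) for w in self.weights]   # one gradient per weight, in order
        self.optimizer.apply(grads)

weights = [{'name': 'w0', 'value': 1.0}, {'name': 'w1', 'value': 2.0}]
cell = ToyTrainCell(ToyOptimizer(weights))
cell.train_step(lambda w: w['value'])   # pretend the gradient equals the current value
print(weights)                          # w0 -> 0.9, w1 -> 1.8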

@@ -133,7 +133,10 @@ class BertLearningRate(LearningRateSchedule):
     """
     def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
         super(BertLearningRate, self).__init__()
-        self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
+        self.warmup_flag = False
+        if warmup_steps > 0:
+            self.warmup_flag = True
+            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
         self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
         self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
@@ -142,8 +145,11 @@ class BertLearningRate(LearningRateSchedule):
         self.cast = P.Cast()

     def construct(self, global_step):
-        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
-        warmup_lr = self.warmup_lr(global_step)
         decay_lr = self.decay_lr(global_step)
-        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        if self.warmup_flag:
+            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
+            warmup_lr = self.warmup_lr(global_step)
+            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        else:
+            lr = decay_lr
         return lr
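This last pair of hunks is the actual zero-warmup fix: WarmUpLR is only constructed when warmup_steps > 0, and construct() falls back to the plain polynomial-decay branch otherwise, so a warmup step count of zero no longer exercises the warmup schedule at all. A plain-Python sketch of the resulting behaviour (the helper functions below stand in for WarmUpLR and PolynomialDecayLR and are assumptions, not the MindSpore implementations):

def warmup(base_lr, warmup_steps, step):
    # linear warmup toward base_lr; ill-defined when warmup_steps == 0
    return base_lr * step / warmup_steps

def polynomial_decay(base_lr, end_lr, decay_steps, power, step):
    frac = min(step, decay_steps) / decay_steps
    return (base_lr - end_lr) * (1.0 - frac) ** power + end_lr

def bert_lr(base_lr, end_lr, warmup_steps, decay_steps, power, step):
    decay_lr = polynomial_decay(base_lr, end_lr, decay_steps, power, step)
    if warmup_steps > 0:                     # the new warmup_flag guard
        if step < warmup_steps:              # is_warmup
            return warmup(base_lr, warmup_steps, step)
        return decay_lr
    return decay_lr                          # warmup_steps == 0: decay only, no division by zero

# With warmup_steps = 0 the schedule is pure polynomial decay from step 0.
for step in (0, 100, 500, 1000):
    print(step, bert_lr(1e-4, 0.0, 0, 1000, 1.0, step))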
