script update for bert

pull/3335/head
shibeiji 5 years ago
parent ca6da6751f
commit af4923123c

@@ -117,8 +117,7 @@ def run_pretrain():
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
+                        {'params': other_params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
         optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
@@ -133,8 +132,7 @@ def run_pretrain():
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
+                        {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
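Both hunks above make the same change inside run_pretrain(): the {'order_params': params} entry is removed from group_params, so Lamb and AdamWeightDecay now receive only the decay and no-decay groups. A minimal sketch of the resulting grouping, assuming the cfg.Lamb.* / cfg.AdamWeightDecay.* settings are passed in as plain arguments (the helper name build_group_params is illustrative, not part of the repository):

def build_group_params(params, decay_filter, weight_decay):
    # Split trainable parameters into a weight-decay group and a no-decay
    # group, as both hunks above do; no 'order_params' entry is appended.
    decay_params = list(filter(decay_filter, params))
    other_params = [p for p in params if p not in decay_params]
    return [{'params': decay_params, 'weight_decay': weight_decay},
            {'params': other_params}]

With only these two groups, weight decay is applied to the filtered parameters and left at its default for the rest; the explicit parameter-ordering hint is no longer supplied to the optimizer.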

@@ -22,7 +22,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -55,7 +55,7 @@ class BertFinetuneCell(nn.Cell):
         super(BertFinetuneCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad',
                                     get_by_list=True,
@@ -158,7 +158,7 @@ class BertSquadCell(nn.Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertSquadCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.reducer_flag = False

@@ -21,7 +21,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -270,7 +270,7 @@ class BertTrainOneStepCell(nn.Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
@@ -349,7 +349,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad',
                                     get_by_list=True,
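The import hunks and the four wrapper-cell constructors above make one coordinated change: ParameterTuple(network.trainable_params()) is replaced with optimizer.parameters, which is why the ParameterTuple import can be dropped from both files. A constructor-only sketch of the pattern, assuming a MindSpore-style training cell (the class name is illustrative; the construct() bodies are untouched by this commit):

import mindspore.nn as nn
from mindspore.ops import composite as C

class TrainCellSketch(nn.Cell):
    """Illustrates taking the weight tuple from the optimizer."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainCellSketch, self).__init__(auto_prefix=False)
        self.network = network
        # Reuse the parameter tuple the optimizer already holds instead of
        # rebuilding it with ParameterTuple(network.trainable_params()), so
        # gradients are taken over exactly the parameters being updated.
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens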

@@ -133,7 +133,10 @@ class BertLearningRate(LearningRateSchedule):
     """
     def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
         super(BertLearningRate, self).__init__()
-        self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
+        self.warmup_flag = False
+        if warmup_steps > 0:
+            self.warmup_flag = True
+            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
         self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
         self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
@@ -142,8 +145,11 @@ class BertLearningRate(LearningRateSchedule):
         self.cast = P.Cast()

     def construct(self, global_step):
-        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
-        warmup_lr = self.warmup_lr(global_step)
         decay_lr = self.decay_lr(global_step)
-        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        if self.warmup_flag:
+            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
+            warmup_lr = self.warmup_lr(global_step)
+            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        else:
+            lr = decay_lr
         return lr
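The schedule change can be checked in isolation: with warmup_steps > 0 the learning rate blends warmup and polynomial decay exactly as before, while warmup_steps == 0 now skips the warmup branch instead of constructing a WarmUpLR with zero steps. A plain-Python restatement of the post-commit behaviour, assuming the usual WarmUpLR (linear ramp) and PolynomialDecayLR formulas (the function name bert_lr_sketch is illustrative):

def bert_lr_sketch(global_step, learning_rate, end_learning_rate,
                   warmup_steps, decay_steps, power):
    # Polynomial decay, clamped at decay_steps.
    ratio = min(global_step, decay_steps) / decay_steps
    decay_lr = (learning_rate - end_learning_rate) * (1 - ratio) ** power + end_learning_rate
    if warmup_steps > 0:
        # Linear warmup, selected while global_step < warmup_steps.
        warmup_lr = learning_rate * min(global_step, warmup_steps) / warmup_steps
        is_warmup = 1.0 if global_step < warmup_steps else 0.0
        return (1.0 - is_warmup) * decay_lr + is_warmup * warmup_lr
    return decay_lr

For example, bert_lr_sketch(10, 1e-4, 0.0, warmup_steps=0, decay_steps=1000, power=1.0) returns the pure decay value, matching the new else branch.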
