@@ -31,6 +31,8 @@ from .transformer_model import TransformerModel
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 5.0
 
+
+# pylint: disable=consider-using-in
 class ClipGradients(nn.Cell):
     """
     Clip gradients.
@@ -48,11 +50,12 @@ class ClipGradients(nn.Cell):
         self.clip_by_norm = nn.ClipByNorm()
         self.cast = P.Cast()
         self.dtype = P.DType()
 
     def construct(self,
                   grads,
                   clip_type,
                   clip_value):
-        #return grads
+        # return grads
         if clip_type != 0 and clip_type != 1:
             return grads
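For reference, the `consider-using-in` warning silenced by the new disable comment is raised by the `clip_type != 0 and clip_type != 1` test in `construct`; pylint prefers a membership check there. A minimal sketch of the spelling pylint asks for (hypothetical, not part of this diff, which deliberately keeps the two-comparison form and disables the check instead, presumably so the graph-compiled `construct` body stays unchanged):

```python
# Pylint's preferred spelling of `clip_type != 0 and clip_type != 1`.
# Illustrative only: the diff keeps the original comparison.
def should_skip_clipping(clip_type):
    return clip_type not in (0, 1)
```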
@@ -83,8 +86,8 @@ class TransformerTrainingLoss(nn.Cell):
         super(TransformerTrainingLoss, self).__init__(auto_prefix=False)
         self.vocab_size = config.vocab_size
         self.onehot = P.OneHot()
-        self.on_value = Tensor(float(1-config.label_smoothing), mstype.float32)
-        self.off_value = Tensor(config.label_smoothing/float(self.vocab_size-1), mstype.float32)
+        self.on_value = Tensor(float(1 - config.label_smoothing), mstype.float32)
+        self.off_value = Tensor(config.label_smoothing / float(self.vocab_size - 1), mstype.float32)
         self.reduce_sum = P.ReduceSum()
         self.reduce_mean = P.ReduceMean()
         self.reshape = P.Reshape()
@@ -92,7 +95,7 @@ class TransformerTrainingLoss(nn.Cell):
         self.flatten = P.Flatten()
         self.neg = P.Neg()
         self.cast = P.Cast()
-        self.flat_shape = (config.batch_size*config.seq_length,)
+        self.flat_shape = (config.batch_size * config.seq_length,)
 
     def construct(self, prediction_scores, label_ids, label_weights):
         """Defines the computation performed."""
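The reformatted `on_value`/`off_value` lines are the two constants of standard label smoothing: the true class receives probability 1 - label_smoothing, and the remaining mass is split evenly over the other vocab_size - 1 classes, so each smoothed target row still sums to 1. A minimal NumPy sketch of the same target construction (the helper name `smooth_one_hot` is illustrative, not from this code):

```python
import numpy as np

def smooth_one_hot(label_ids, vocab_size, label_smoothing):
    """Smoothed one-hot targets: 1 - eps on the label, eps / (V - 1) elsewhere."""
    on_value = 1.0 - label_smoothing
    off_value = label_smoothing / float(vocab_size - 1)
    targets = np.full((len(label_ids), vocab_size), off_value, dtype=np.float32)
    targets[np.arange(len(label_ids)), label_ids] = on_value
    return targets  # each row sums to on_value + (V - 1) * off_value == 1.0
```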
@@ -217,10 +220,12 @@ class TransformerTrainOneStepCell(nn.Cell):
 grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
 
+
 @grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
+
 class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
     """
     Encapsulation class of Transformer network training.
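`tensor_grad_scale` is the unscaling half of loss-scaled training: the backward pass produces gradients of `scale * loss`, so each gradient is multiplied by `1 / scale` (computed with `Reciprocal` and cast to the gradient's dtype) before clipping and the optimizer step. A rough NumPy sketch of that arithmetic (the function name and numbers are illustrative):

```python
import numpy as np

def unscale_grad(scale, grad):
    """Mirror of tensor_grad_scale: grad * (1 / scale), kept in grad's dtype."""
    return grad * np.asarray(1.0 / scale, dtype=grad.dtype)

scaled = np.array([512.0, -2048.0], dtype=np.float16)  # grads of 1024 * loss
print(unscale_grad(1024.0, scaled))                    # -> [ 0.5 -2. ]
```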