@@ -32,44 +32,31 @@ from .bert_model import BertModel
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
+_nn_clip_by_norm = nn.ClipByNorm()
+clip_grad = C.MultitypeFuncGraph("clip_grad")
 
 
-class ClipGradients(nn.Cell):
+@clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
     """
     Clip gradients.
 
     Inputs:
-        grads (tuple[Tensor]): Gradients.
         clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
         clip_value (float): Specifies how much to clip.
+        grad (tuple[Tensor]): Gradients.
 
     Outputs:
         tuple[Tensor], clipped gradients.
     """
-    def __init__(self):
-        super(ClipGradients, self).__init__()
-        self.clip_by_norm = nn.ClipByNorm()
-        self.cast = P.Cast()
-        self.dtype = P.DType()
-
-    def construct(self,
-                  grads,
-                  clip_type,
-                  clip_value):
-        if clip_type != 0 and clip_type != 1:
-            return grads
-
-        new_grads = ()
-        for grad in grads:
-            dt = self.dtype(grad)
-            if clip_type == 0:
-                t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
-                                    self.cast(F.tuple_to_array((clip_value,)), dt))
-            else:
-                t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
-            new_grads = new_grads + (t,)
-
-        return new_grads
+    if clip_type != 0 and clip_type != 1:
+        return grad
+    dt = F.dtype(grad)
+    if clip_type == 0:
+        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                   F.cast(F.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+    return new_grad
 
 
 class GetMaskedLMOutput(nn.Cell):
     """
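
Note: the per-tensor math in the new _clip_grad is the same clamp-or-rescale behaviour the removed
ClipGradients.construct loop applied. A NumPy sketch of that math for one gradient tensor, as an
illustration of the semantics only, not the MindSpore ops; the sample values are made up:

    import numpy as np

    def clip_one(grad, clip_type, clip_value):
        if clip_type == 0:        # 'value': clamp each element into [-clip_value, clip_value]
            return np.clip(grad, -clip_value, clip_value)
        if clip_type == 1:        # 'norm': rescale only when the L2 norm exceeds clip_value
            norm = np.linalg.norm(grad)
            return grad if norm <= clip_value else grad * (clip_value / norm)
        return grad               # any other clip_type passes the gradient through, as in _clip_grad

    g = np.array([3.0, -4.0], dtype=np.float32)
    print(clip_one(g, 0, 1.0))    # [ 1. -1.]
    print(clip_one(g, 1, 1.0))    # [ 0.6 -0.8], the L2 norm 5.0 is scaled down to 1.0
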
@@ -294,8 +281,8 @@ class BertTrainOneStepCell(nn.Cell):
             degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 
-        self.clip_gradients = ClipGradients()
         self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
 
     def set_sens(self, value):
         self.sens = value
@@ -327,7 +314,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  masked_lm_weights,
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
-        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
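
Note: hyper_map(F.partial(clip_grad, ...), grads) maps the registered _clip_grad over every tensor in
the gradient tuple, replacing the explicit per-gradient loop of the removed ClipGradients cell. A
minimal sketch of the pattern in isolation, assuming the module-level clip_grad, GRADIENT_CLIP_TYPE
and GRADIENT_CLIP_VALUE introduced above are in scope; the ClipCell name, imports and sample tensors
are illustrative only:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import composite as C
    from mindspore.ops import functional as F

    class ClipCell(nn.Cell):
        """Applies clip_grad to each tensor of a gradient tuple via HyperMap."""
        def __init__(self):
            super(ClipCell, self).__init__()
            self.hyper_map = C.HyperMap()

        def construct(self, grads):
            # F.partial fixes clip_type and clip_value; HyperMap supplies the remaining Tensor
            # argument, dispatching to _clip_grad via its ("Number", "Number", "Tensor") registration.
            return self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

    grads = (Tensor(np.array([3.0, -4.0], np.float32)),
             Tensor(np.array([0.5], np.float32)))
    clipped = ClipCell()(grads)
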
@@ -376,7 +363,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.clip_gradients = ClipGradients()
         self.cast = P.Cast()
         self.alloc_status = P.NPUAllocFloatStatus()
         self.get_status = P.NPUGetFloatStatus()
@@ -427,7 +413,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
                                                  self.cast(scaling_sens,
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
-        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         self.get_status(init)
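
Note: the order of operations in BertTrainOneStepWithLossScaleCell.construct is unchanged: gradients
are first unscaled with grad_scale (dividing out scaling_sens), then clipped with clip_grad, then
handed to the grad reducer. A NumPy illustration of why unscaling comes before clipping, with
value-mode clipping shown and all numbers made up:

    import numpy as np

    loss_scale = 1024.0
    clip_value = 1.0
    scaled_grad = np.array([512.0, -2048.0], dtype=np.float32)  # gradient produced under loss scaling

    unscaled = scaled_grad / loss_scale                   # step 1: grad_scale divides out the loss scale
    clipped = np.clip(unscaled, -clip_value, clip_value)  # step 2: clip_grad clips the true gradient
    print(clipped)                                        # [ 0.5 -1. ]

    # Clipping the still-scaled gradient first would clamp nearly everything to +/-clip_value and
    # then shrink it by 1/loss_scale, which is why grad_scale is applied ahead of clip_grad above.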