|
|
|
@ -280,7 +280,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
|
|
|
|
|
group_scale_name = self.group_name + "_scale"
|
|
|
|
|
if group_scale_name not in self.context:
|
|
|
|
|
group_norm_var = layers.sums(input=self.context[self.group_name])
|
|
|
|
|
layers.sqrt(x=group_norm_var, out=group_norm_var)
|
|
|
|
|
group_norm_var = layers.sqrt(x=group_norm_var)
|
|
|
|
|
clip_var = self.context[self.group_name + "_clip"]
|
|
|
|
|
group_scale_var = layers.elementwise_div(
|
|
|
|
|
x=clip_var,
|
|
|
|
|