|
|
|
@ -590,7 +590,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
|
|
|
|
|
global_norm_var = layers.reduce_sum(global_norm_var)
|
|
|
|
|
global_norm_var = layers.sqrt(global_norm_var)
|
|
|
|
|
max_global_norm = layers.fill_constant(
|
|
|
|
|
shape=[1], dtype='float32', value=self.clip_norm)
|
|
|
|
|
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
|
|
|
|
|
clip_var = layers.elementwise_div(
|
|
|
|
|
x=max_global_norm,
|
|
|
|
|
y=layers.elementwise_max(
|
|
|
|
@ -635,7 +635,9 @@ class GradientClipByGlobalNorm(GradientClipBase):
|
|
|
|
|
global_norm_var = layers.sums(sum_square_list)
|
|
|
|
|
global_norm_var = layers.sqrt(x=global_norm_var)
|
|
|
|
|
max_global_norm = layers.fill_constant(
|
|
|
|
|
shape=[1], dtype="float32", value=self.clip_norm)
|
|
|
|
|
shape=[1],
|
|
|
|
|
dtype=global_norm_var.dtype,
|
|
|
|
|
value=self.clip_norm)
|
|
|
|
|
scale_var = layers.elementwise_div(
|
|
|
|
|
x=max_global_norm,
|
|
|
|
|
y=layers.elementwise_max(
|
|
|
|
@ -663,7 +665,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
|
|
|
|
|
context[self.group_name] = []
|
|
|
|
|
context[self.group_name + "_clip_value"] = self.clip_norm
|
|
|
|
|
context[self.group_name + "_clip"] = layers.fill_constant(
|
|
|
|
|
shape=[1], dtype="float32", value=self.clip_norm)
|
|
|
|
|
shape=[1], dtype=grad.dtype, value=self.clip_norm)
|
|
|
|
|
else:
|
|
|
|
|
if not self.clip_norm == context[self.group_name + "_clip_value"]:
|
|
|
|
|
raise ValueError(
|
|
|
|
|