|
|
|
@ -77,6 +77,18 @@ class GradientClipByValue(BaseGradientClipAttr):
|
|
|
|
|
return param, new_grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradientClipByNorm(BaseGradientClipAttr):
    """Clip attribute that routes each gradient through ``layers.clip_by_norm``.

    The value supplied at construction is kept on the instance and later
    forwarded to ``layers.clip_by_norm`` as its ``max_norm`` argument.
    """

    def __init__(self, clip_norm):
        """Remember ``clip_norm`` for use when the clip operator is created.

        Args:
            clip_norm: Value passed unchanged as ``max_norm`` to
                ``layers.clip_by_norm``; it is stored as given.
        """
        self.clip_norm = clip_norm

    def process_context(self, context, p_g):
        """Do nothing; ``context`` and ``p_g`` are neither read nor changed."""
        return None

    def create_operators(self, param, grad):
        """Build the clipped gradient for one parameter/gradient pair.

        Args:
            param: Returned unchanged as the first element of the result.
            grad: Passed as ``x`` to ``layers.clip_by_norm``.

        Returns:
            A ``(param, clipped_grad)`` tuple, where ``clipped_grad`` is the
            result of ``layers.clip_by_norm``.
        """
        clipped_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
        return param, clipped_grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_gradient_clip_ops(param_grad):
|
|
|
|
|
context = dict()
|
|
|
|
|
create_op_callbacks = []
|
|
|
|
|