|
|
@ -40,6 +40,11 @@ def error_clip_callback(block, context):
|
|
|
|
op_desc.output_arg_names()):
|
|
|
|
op_desc.output_arg_names()):
|
|
|
|
fwd_var = block.var_recursive(grad_to_var[grad_n])
|
|
|
|
fwd_var = block.var_recursive(grad_to_var[grad_n])
|
|
|
|
error_clip = getattr(fwd_var, "error_clip", None)
|
|
|
|
error_clip = getattr(fwd_var, "error_clip", None)
|
|
|
|
|
|
|
|
if not (error_clip is None or isinstance(error_clip,
|
|
|
|
|
|
|
|
BaseErrorClipAttr)):
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
|
|
|
|
|
|
|
|
)
|
|
|
|
if error_clip is not None:
|
|
|
|
if error_clip is not None:
|
|
|
|
error_clip.append_clip_op(block, grad_n)
|
|
|
|
error_clip.append_clip_op(block, grad_n)
|
|
|
|
|
|
|
|
|
|
|
|