|
|
|
@ -46,12 +46,12 @@ class ErrorClipByValue(BaseErrorClipAttr):
|
|
|
|
|
self.min = min
|
|
|
|
|
|
|
|
|
|
def append_clip_op(self, block, grad_name):
|
|
|
|
|
block.append_op(
|
|
|
|
|
type="clip",
|
|
|
|
|
inputs={"X": grad_name},
|
|
|
|
|
outputs={"Out": grad_name},
|
|
|
|
|
attrs={"min": self.min,
|
|
|
|
|
"max": self.max})
|
|
|
|
|
clip_op_desc = block.desc.append_op()
|
|
|
|
|
clip_op_desc.set_type("clip")
|
|
|
|
|
clip_op_desc.set_input("X", [grad_name])
|
|
|
|
|
clip_op_desc.set_output("Out", [grad_name])
|
|
|
|
|
clip_op_desc.set_attr("min", self.min)
|
|
|
|
|
clip_op_desc.set_attr("max", self.max)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
The `BaseErrorClipAttr` have one main member functions: `append_clip_op(self, block, grad_name)`.
|
|
|
|
@ -80,6 +80,11 @@ def error_clip_callback(block, context):
|
|
|
|
|
op_desc.output_arg_names()):
|
|
|
|
|
fwd_var = block.var_recursive(grad_to_var[grad_n])
|
|
|
|
|
error_clip = getattr(fwd_var, "error_clip", None)
|
|
|
|
|
if not (error_clip is None or isinstance(error_clip,
|
|
|
|
|
BaseErrorClipAttr)):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
|
|
|
|
|
)
|
|
|
|
|
if error_clip is not None:
|
|
|
|
|
error_clip.append_clip_op(block, grad_n)
|
|
|
|
|
```
|
|
|
|
|