@@ -6,18 +6,9 @@ __all__ = ['GradientClipByValue', 'append_gradient_clip_ops']
 
 
 class BaseErrorClipAttr(object):
-    def create_clip_op_desc(self, grad_name):
+    def append_clip_op(self, block, grad_name):
         raise NotImplementedError()
 
-    def prepend_clip_op_desc(self, op_descs):
-        grad_names = set()
-        for op_desc in op_descs:
-            grad_names.update(
-                filter(lambda n: n.find(core.grad_var_suffix()) != -1,
-                       op_desc.output_arg_names()))
-        for n in grad_names:
-            op_descs.append(self.create_clip_op_desc(grad_name=n))
-
 
 class ErrorClipByValue(BaseErrorClipAttr):
     def __init__(self, max, min=None):
@@ -29,14 +20,25 @@ class ErrorClipByValue(BaseErrorClipAttr):
         self.max = max
         self.min = min
 
-    def create_clip_op_desc(self, grad_name):
-        desc = core.OpDesc()
-        desc.set_type("clip")
-        desc.set_input("X", grad_name)
-        desc.set_output("Out", grad_name)
-        desc.set_attr("min", self.min)
-        desc.set_attr("max", self.max)
-        return desc
+    def append_clip_op(self, block, grad_name):
+        block.append_op(
+            type="clip",
+            inputs={"X": grad_name},
+            outputs={"Out": grad_name},
+            attrs={"min": self.min,
+                   "max": self.max})
+
+
+def error_clip_callback(block, context):
+    # the context is a grad_to_var map
+    grad_to_var = context
+    op_desc = block.desc.op(block.desc.op_size() - 1)
+    for grad_n in filter(lambda n: grad_to_var.has_key(n),
+                         op_desc.output_arg_names()):
+        fwd_var = block.var_recursive(grad_to_var[grad_n])
+        error_clip = getattr(fwd_var, "error_clip", None)
+        if error_clip is not None:
+            error_clip.append_clip_op(block, grad_n)
 
 
 class BaseGradientClipAttr(object):
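
The diff replaces the desc-level create_clip_op_desc / prepend_clip_op_desc pair with a block-level append_clip_op plus an error_clip_callback that, judging from its use of the last op in block.desc, is intended to run after each backward op is appended. The sketch below is a minimal, runnable illustration of that contract under stated assumptions: FakeVar and FakeBlock are stand-ins invented for this example (not Paddle's real Variable/Block), the grad_to_var mapping stands in for the `context` the backward builder passes, and ErrorClipByValue's logic is reproduced from the diff so the sketch runs stand-alone.

# Stand-in types invented for this sketch; Paddle's real Block/Variable differ.
class FakeVar(object):
    def __init__(self, name, error_clip=None):
        self.name = name
        self.error_clip = error_clip  # an ErrorClipByValue instance or None


class FakeBlock(object):
    def __init__(self, variables):
        self._vars = dict((v.name, v) for v in variables)
        self.ops = []

    def var_recursive(self, name):
        return self._vars[name]

    def append_op(self, type, inputs, outputs, attrs):
        # Record the op instead of building a real OpDesc.
        self.ops.append({"type": type, "inputs": inputs,
                        "outputs": outputs, "attrs": attrs})


class ErrorClipByValue(object):
    # Clipping logic reproduced from the diff so this example is self-contained.
    def __init__(self, max, min=None):
        self.max = float(max)
        self.min = -self.max if min is None else float(min)

    def append_clip_op(self, block, grad_name):
        block.append_op(type="clip",
                        inputs={"X": grad_name},
                        outputs={"Out": grad_name},
                        attrs={"min": self.min, "max": self.max})


# Simulate the callback's core loop for one backward op whose output gradient
# is "hidden@GRAD"; grad_to_var maps gradient names to forward variable names.
grad_to_var = {"hidden@GRAD": "hidden"}
block = FakeBlock([FakeVar("hidden", error_clip=ErrorClipByValue(max=5.0))])
for grad_n in grad_to_var:
    fwd_var = block.var_recursive(grad_to_var[grad_n])
    error_clip = getattr(fwd_var, "error_clip", None)
    if error_clip is not None:
        error_clip.append_clip_op(block, grad_n)

print(block.ops)  # one "clip" op on "hidden@GRAD" with min=-5.0, max=5.0

The design point this illustrates: clipping is no longer injected by rewriting the list of backward op descs up front; instead each forward variable can carry an error_clip attribute, and the callback appends a clip op on the matching gradient right where the backward op was added.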