|
|
|
@ -77,7 +77,7 @@ class Cell:
|
|
|
|
|
if flags:
|
|
|
|
|
self.add_flags(**flags)
|
|
|
|
|
self._backward_hook = None
|
|
|
|
|
self._enable_hook = False
|
|
|
|
|
self.enable_hook = False
|
|
|
|
|
self._bprop_debug = False
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@ -97,10 +97,24 @@ class Cell:
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def bprop_debug(self):
|
|
|
|
|
"""
|
|
|
|
|
Get whether cell custom bprop debug is enabled.
|
|
|
|
|
"""
|
|
|
|
|
return self._bprop_debug
|
|
|
|
|
|
|
|
|
|
@bprop_debug.setter
|
|
|
|
|
def bprop_debug(self, value):
|
|
|
|
|
"""
|
|
|
|
|
Set whether to enable cell custom bprop debug.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
When bprop is defined in cell, the bprop function will be executed
|
|
|
|
|
in python interpreter when bprop debug is true, and will be parsed
|
|
|
|
|
and add to graph when bprop debug is false.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
value (bool): Specifies whether to enable bprop debug. Default: False.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(value, bool):
|
|
|
|
|
raise TypeError("'bprop debug' value must be bool type.")
|
|
|
|
|
self._bprop_debug = value
|
|
|
|
@ -755,17 +769,19 @@ class Cell:
|
|
|
|
|
outputs = self._backward_hook(inputs)
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def enable_hook(self):
|
|
|
|
|
"""Whether the cell register hook function"""
|
|
|
|
|
return self._enable_hook
|
|
|
|
|
|
|
|
|
|
def register_backward_hook(self, fn):
|
|
|
|
|
"""
|
|
|
|
|
Set the cell backward hook function.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
fn should be defined as following code shows, `cell_name` is the name of registered cell,
|
|
|
|
|
`grad_input` is gradient passed to the cell, `grad_output` is the gradient computed and pass to
|
|
|
|
|
next cell or primitve, which may be modified and return.
|
|
|
|
|
>>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fn (function): Specifies the hook function with grad as input.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
|
|
|
|
|
self._enable_hook = True
|
|
|
|
|