|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
"""debug_ops"""
|
|
|
|
|
from types import FunctionType
|
|
|
|
|
from types import FunctionType, MethodType
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer):
|
|
|
|
|
super(HookBackward, self).__init__(self.__class__.__name__)
|
|
|
|
|
self.add_prim_attr("cell_id", cell_id)
|
|
|
|
|
self.init_attrs["cell_id"] = cell_id
|
|
|
|
|
if not isinstance(hook_fn, FunctionType):
|
|
|
|
|
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
|
|
|
|
raise TypeError("Hook function should be python function type.")
|
|
|
|
|
self.register_hook(hook_fn)
|
|
|
|
|
self.cell_id = cell_id
|
|
|
|
|