|
|
|
@ -24,10 +24,6 @@ __all__ = ['Tracer']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def release_op(op):
|
|
|
|
|
import gc
|
|
|
|
|
assert len(
|
|
|
|
|
gc.get_referrers(framework._imperative_tracer()._ops[
|
|
|
|
|
op._trace_id])) == 1
|
|
|
|
|
del framework._imperative_tracer()._ops[op._trace_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -59,6 +55,8 @@ class Tracer(core.Tracer):
|
|
|
|
|
if len(backward_refs) > 0:
|
|
|
|
|
op.iop.register_backward_hooks(release_op)
|
|
|
|
|
|
|
|
|
|
# TODO(minqiyang): remove all inputs and outputs after seperate
|
|
|
|
|
# var and grad
|
|
|
|
|
op.backward_refs = defaultdict(list)
|
|
|
|
|
for k, v in six.iteritems(op.inputs):
|
|
|
|
|
if k in backward_refs:
|
|
|
|
|