|
|
|
@ -1193,13 +1193,13 @@ class Block(object):
|
|
|
|
|
raise ValueError("Var {0} is not found recursively".format(name))
|
|
|
|
|
|
|
|
|
|
def _clear_block(self):
|
|
|
|
|
# TODO(minqiyang): move this to backward_hooks
|
|
|
|
|
self.desc._clear_block()
|
|
|
|
|
assert _in_imperative_mode()
|
|
|
|
|
|
|
|
|
|
for name in self.vars.keys():
|
|
|
|
|
assert self.vars[name].persistable
|
|
|
|
|
# TODO(minqiyang): move this to Variable and Operator's __del__
|
|
|
|
|
self.desc._clear_block()
|
|
|
|
|
|
|
|
|
|
del self.ops[:]
|
|
|
|
|
assert len(self.vars) == 0
|
|
|
|
|
assert len(self.ops) == 0
|
|
|
|
|
|
|
|
|
|
def all_parameters(self):
|
|
|
|
|
return list(self.iter_parameters())
|
|
|
|
@ -1337,26 +1337,13 @@ class Block(object):
|
|
|
|
|
#
|
|
|
|
|
# TODO(minqiyang): add op stop_gradient support in static mode too.
|
|
|
|
|
# currently, we only support stop_gradient in imperative mode.
|
|
|
|
|
self._trace_op(op, kwargs.get("stop_gradient", False))
|
|
|
|
|
self.ops.append(op)
|
|
|
|
|
_imperative_tracer().trace_op(op,
|
|
|
|
|
kwargs.get("stop_gradient", False))
|
|
|
|
|
else:
|
|
|
|
|
self.ops.append(op)
|
|
|
|
|
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
def _trace_op(self, op, stop_gradient=False):
|
|
|
|
|
backward_refs = _imperative_tracer().trace(
|
|
|
|
|
op.iop, op.inputs, op.outputs, self.desc,
|
|
|
|
|
_imperative_current_expected_place_, stop_gradient)
|
|
|
|
|
|
|
|
|
|
# TODO(minqiyang): support backward_hooks to eager remove backward_refs
|
|
|
|
|
op.backward_refs = defaultdict(list)
|
|
|
|
|
for k, v in six.iteritems(op.inputs):
|
|
|
|
|
if k in backward_refs:
|
|
|
|
|
op.backward_refs[k] = op.inputs[k]
|
|
|
|
|
|
|
|
|
|
for k, v in six.iteritems(op.outputs):
|
|
|
|
|
if k in backward_refs:
|
|
|
|
|
op.backward_refs[k] = op.outputs[k]
|
|
|
|
|
|
|
|
|
|
def _insert_op(self, index, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Insert a Operator according to the giving arguments.
|
|
|
|
@ -1409,9 +1396,11 @@ class Block(object):
|
|
|
|
|
inputs=kwargs.get("inputs", None),
|
|
|
|
|
outputs=kwargs.get("outputs", None),
|
|
|
|
|
attrs=kwargs.get("attrs", None))
|
|
|
|
|
self.ops.insert(0, op)
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
self._trace_op(op, kwargs.get("stop_gradient", False))
|
|
|
|
|
_imperative_tracer().trace_op(op,
|
|
|
|
|
kwargs.get("stop_gradient", False))
|
|
|
|
|
else:
|
|
|
|
|
self.ops.insert(0, op)
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
def _sync_with_cpp(self):
|
|
|
|
|