@@ -376,15 +376,17 @@ class Variable(object):
                 # get_capacity is implemented
                 pass
 
-        self.block.vars[name] = self
-        self.op = None
-        self.stop_gradient = stop_gradient
-        self.is_data = is_data
         if _in_imperative_mode():
+            # record vars in tracer rather than blocks
             self._ivar = kwargs.get("ivar", None)
             if not self._ivar:
                 self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
+        else:
+            self.block.vars[name] = self
+        self.op = None
+        self.stop_gradient = stop_gradient
+        self.is_data = is_data
 
     def _numpy(self):
         new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
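
In the hunk above, a Variable constructed under imperative mode is backed by a core.VarBase handle (self._ivar) and is no longer registered into block.vars; only the declarative path still records the variable on its block. A rough sketch of that branching with placeholder classes (EagerHandle and SketchVariable are illustrative names, not Paddle APIs):

    class EagerHandle(object):
        # Stand-in for core.VarBase: owns the tensor on the tracer side.
        def __init__(self, stop_gradient):
            self.stop_gradient = stop_gradient
            self.desc = None

    class SketchVariable(object):
        def __init__(self, block_vars, name, imperative, stop_gradient=False,
                     ivar=None):
            if imperative:
                # record the var in a tracer-side handle rather than the block
                self._ivar = ivar if ivar is not None else EagerHandle(stop_gradient)
            else:
                # declarative (static graph) path: register on the block as before
                block_vars[name] = self
            self.op = None
            self.stop_gradient = stop_gradient

    block_vars = {}
    SketchVariable(block_vars, "fc_0.w_0", imperative=False)
    v = SketchVariable(block_vars, "eager_tmp_0", imperative=True)
    print(sorted(block_vars))   # ['fc_0.w_0']; the eager var was not registered
    print(hasattr(v, "_ivar"))  # True
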
@@ -727,6 +729,7 @@ class Operator(object):
         if _in_imperative_mode():
             self.iop = core.OpBase()
             self.iop.desc = self.desc
+
             self.inputs = defaultdict(list)
             if inputs is not None:
                 for k, v in six.iteritems(inputs):
@@ -734,6 +737,7 @@ class Operator(object):
                         self.inputs[k].append(v._ivar)
                     elif isinstance(v, list) or isinstance(v, tuple):
                         self.inputs[k].extend([var._ivar for var in v])
+
             self.outputs = defaultdict(list)
             if outputs is not None:
                 for k, v in six.iteritems(outputs):
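
For context on the two Operator.__init__ hunks above: in imperative mode every input slot (and, symmetrically, every output slot) is normalized into a flat list of _ivar handles, whether the caller passed a single Variable or a list/tuple of Variables. A minimal sketch of that normalization with plain Python stand-ins (FakeVariable and FakeIVar are illustrative, not the Paddle classes):

    from collections import defaultdict

    class FakeIVar(object):
        # Stand-in for the C++-side core.VarBase handle.
        def __init__(self, name):
            self.name = name

    class FakeVariable(object):
        # Stand-in for framework.Variable; only carries an _ivar handle.
        def __init__(self, name):
            self._ivar = FakeIVar(name)

    def normalize_slots(slots_in):
        # Mirrors the loop in Operator.__init__: each slot becomes a list of
        # _ivar handles, regardless of how the caller passed the Variables.
        slots = defaultdict(list)
        for k, v in slots_in.items():
            if isinstance(v, FakeVariable):
                slots[k].append(v._ivar)
            elif isinstance(v, (list, tuple)):
                slots[k].extend(var._ivar for var in v)
        return slots

    ins = {"X": FakeVariable("x"), "Y": [FakeVariable("y0"), FakeVariable("y1")]}
    print({k: [h.name for h in v] for k, v in normalize_slots(ins).items()})
    # {'X': ['x'], 'Y': ['y0', 'y1']}
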
@@ -1186,8 +1190,8 @@ class Block(object):
     def _clear_block(self):
         self.desc._clear_block()
 
-        for name, var in self.vars.items():
-            if not var.persistable:
+        for name in self.vars.keys():
+            if not self.vars[name].persistable:
                 del self.vars[name]
 
         del self.ops[:]
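
A note on the _clear_block hunk above: the switch from items() to keys() still deletes from self.vars while iterating over it. Under Python 2 that is safe because keys() returns a list, but under Python 3 both keys() and items() are live views, and deleting during iteration raises RuntimeError; taking a snapshot of the keys avoids it. A small illustration on a plain dict:

    block_vars = {"w": "persistable", "tmp_0": "temp", "tmp_1": "temp"}

    # Python 3: snapshot the keys before mutating, otherwise deleting while
    # iterating a live view raises "dictionary changed size during iteration".
    for name in list(block_vars.keys()):
        if block_vars[name] == "temp":
            del block_vars[name]

    print(block_vars)  # {'w': 'persistable'}
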
@@ -1322,18 +1326,34 @@ class Block(object):
             inputs=kwargs.get("inputs", None),
             outputs=kwargs.get("outputs", None),
             attrs=kwargs.get("attrs", None))
+
+        if _in_imperative_mode():
+            # record ops in tracer rather than blocks
+            #
+            # TODO(minqiyang): add op stop_gradient support in static mode too.
+            # currently, we only support stop_gradient in imperative mode.
+            self._trace_op(op, kwargs.get("stop_gradient", False))
         self.ops.append(op)
 
-        # TODO(minqiyang): add stop_gradient support in static mode too.
-        # currently, we only support stop_gradient in imperative mode.
-        self._trace_op(op, kwargs.get("stop_gradient", False))
         return op
 
     def _trace_op(self, op, stop_gradient=False):
-        if _in_imperative_mode():
-            _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
-                                       _imperative_current_expected_place_,
-                                       stop_gradient)
+        backward_refs = _imperative_tracer().trace(
+            op.iop, op.inputs, op.outputs, self.desc,
+            _imperative_current_expected_place_, stop_gradient)
+        print("backward_refs", backward_refs)
+        import sys
+        sys.stdout.flush()
+
+        # TODO(minqiyang): support backward hooks to eager remove backward_refs
+        op.backward_refs = defaultdict(list)
+        for k, v in six.iteritems(op.inputs):
+            if k in backward_refs:
+                op.backward_refs[k] = op.inputs[k]
+
+        for k, v in six.iteritems(op.outputs):
+            if k in backward_refs:
+                op.backward_refs[k] = op.outputs[k]
 
     def _insert_op(self, index, *args, **kwargs):
         """
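
In the reworked _trace_op above, the tracer call now returns backward_refs, and the op keeps only the input and output slots whose names appear in it, apparently so that only the tensors the backward pass still needs stay referenced (the TODO about backward hooks points the same way). A minimal sketch of that filtering step, with plain dicts standing in for op.inputs/op.outputs and backward_refs assumed to behave like a set of slot names:

    from collections import defaultdict

    def keep_backward_refs(slots, backward_refs):
        # Retain only the slots the backward pass will read; everything else
        # can be dropped once the forward computation has run.
        kept = defaultdict(list)
        for k, v in slots.items():
            if k in backward_refs:
                kept[k] = v
        return kept

    op_inputs = {"X": ["x_ivar"], "Scale": ["scale_ivar"]}
    backward_refs = {"X", "Out"}  # assumed: slot names the backward op reads
    print(dict(keep_backward_refs(op_inputs, backward_refs)))
    # {'X': ['x_ivar']}
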
@@ -1388,7 +1408,8 @@ class Block(object):
             outputs=kwargs.get("outputs", None),
             attrs=kwargs.get("attrs", None))
         self.ops.insert(0, op)
-        self._trace_op(op, kwargs.get("stop_gradient", False))
+        if _in_imperative_mode():
+            self._trace_op(op, kwargs.get("stop_gradient", False))
         return op
 
     def _sync_with_cpp(self):