|
|
|
@ -54,47 +54,24 @@ class Tracer(core.Tracer):
|
|
|
|
|
self._trace_id = 0
|
|
|
|
|
|
|
|
|
|
def trace_op(self, op, inputs, outputs, stop_gradient=False):
|
|
|
|
|
# TODO(minqiyang): remove this line after we take apart all
|
|
|
|
|
# backward grads and forward variables
|
|
|
|
|
if self._train_mode:
|
|
|
|
|
op.inputs = inputs
|
|
|
|
|
inps = defaultdict(list)
|
|
|
|
|
for k, vars in six.iteritems(inputs):
|
|
|
|
|
if isinstance(vars, framework.Variable):
|
|
|
|
|
inps[k].append(vars._ivar)
|
|
|
|
|
elif isinstance(vars, list) or isinstance(vars, tuple):
|
|
|
|
|
for var in vars:
|
|
|
|
|
inps[k].append(var._ivar)
|
|
|
|
|
|
|
|
|
|
op.outputs = outputs
|
|
|
|
|
outs = defaultdict(list)
|
|
|
|
|
for k, vars in six.iteritems(outputs):
|
|
|
|
|
if isinstance(vars, framework.Variable):
|
|
|
|
|
outs[k].append(vars._ivar)
|
|
|
|
|
elif isinstance(vars, list) or isinstance(vars, tuple):
|
|
|
|
|
for var in vars:
|
|
|
|
|
outs[k].append(var._ivar)
|
|
|
|
|
else:
|
|
|
|
|
inps = defaultdict(list)
|
|
|
|
|
for k, vars in six.iteritems(inputs):
|
|
|
|
|
if isinstance(vars, framework.Variable):
|
|
|
|
|
op.previous_ops.append(vars.op)
|
|
|
|
|
inps[k].append(vars._ivar)
|
|
|
|
|
elif isinstance(vars, list) or isinstance(vars, tuple):
|
|
|
|
|
for var in vars:
|
|
|
|
|
op.previous_ops.append(var.op)
|
|
|
|
|
inps[k].append(var._ivar)
|
|
|
|
|
|
|
|
|
|
op.outputs = outputs
|
|
|
|
|
outs = defaultdict(list)
|
|
|
|
|
for k, vars in six.iteritems(outputs):
|
|
|
|
|
if isinstance(vars, framework.Variable):
|
|
|
|
|
vars.op = op
|
|
|
|
|
outs[k].append(vars._ivar)
|
|
|
|
|
elif isinstance(vars, list) or isinstance(vars, tuple):
|
|
|
|
|
for var in vars:
|
|
|
|
|
var.op = op
|
|
|
|
|
outs[k].append(var._ivar)
|
|
|
|
|
# TODO(hy): previous version will cause memory failed
|
|
|
|
|
op.inputs = inputs
|
|
|
|
|
inps = defaultdict(list)
|
|
|
|
|
for k, vars in six.iteritems(inputs):
|
|
|
|
|
if isinstance(vars, framework.Variable):
|
|
|
|
|
inps[k].append(vars._ivar)
|
|
|
|
|
elif isinstance(vars, list) or isinstance(vars, tuple):
|
|
|
|
|
for var in vars:
|
|
|
|
|
inps[k].append(var._ivar)
|
|
|
|
|
|
|
|
|
|
op.outputs = outputs
|
|
|
|
|
outs = defaultdict(list)
|
|
|
|
|
for k, vars in six.iteritems(outputs):
|
|
|
|
|
if isinstance(vars, framework.Variable):
|
|
|
|
|
outs[k].append(vars._ivar)
|
|
|
|
|
elif isinstance(vars, list) or isinstance(vars, tuple):
|
|
|
|
|
for var in vars:
|
|
|
|
|
outs[k].append(var._ivar)
|
|
|
|
|
|
|
|
|
|
# record op's trace id
|
|
|
|
|
op.iop._trace_id = self._trace_id
|
|
|
|
|