|
|
|
@ -14,7 +14,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
import six
|
|
|
|
import six
|
|
|
|
|
|
|
|
from six.moves import reduce
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
from collections import defaultdict
|
|
|
|
from paddle.fluid import core
|
|
|
|
from paddle.fluid import core
|
|
|
|
@ -49,7 +51,16 @@ class Tracer(core.Tracer):
|
|
|
|
def trace_op(self, op, stop_gradient=False):
|
|
|
|
def trace_op(self, op, stop_gradient=False):
|
|
|
|
# record op's trace id
|
|
|
|
# record op's trace id
|
|
|
|
op.iop._trace_id = self._trace_id
|
|
|
|
op.iop._trace_id = self._trace_id
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
all_input_stop_grads = True
|
|
|
|
|
|
|
|
for vars in op.inputs.values():
|
|
|
|
|
|
|
|
for v in vars:
|
|
|
|
|
|
|
|
sys.stderr.write('%s %s\n' % (v.name, v.stop_gradient))
|
|
|
|
|
|
|
|
all_input_stop_grads &= v.stop_gradient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stop_gradient = False if not stop_gradient else True
|
|
|
|
|
|
|
|
stop_gradient = all_input_stop_grads | stop_gradient
|
|
|
|
|
|
|
|
"""
|
|
|
|
backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
|
|
|
|
backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
|
|
|
|
framework._current_expected_place(),
|
|
|
|
framework._current_expected_place(),
|
|
|
|
stop_gradient)
|
|
|
|
stop_gradient)
|
|
|
|
|