|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
from paddle.v2.fluid import framework as framework
|
|
|
|
|
from . import core
|
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
__all__ = ['append_backward_ops']
|
|
|
|
|
|
|
|
|
@ -20,6 +21,20 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
|
|
|
|
|
no_grad_set[block.idx],
|
|
|
|
|
grad_to_var, grad_sub_block_list)
|
|
|
|
|
grad_op_descs.append(grad_op_desc)
|
|
|
|
|
# grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
|
|
|
|
|
# flatten grad_op_descs
|
|
|
|
|
grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
|
|
|
|
|
|
|
|
|
|
output_vars = collections.defaultdict(list)
|
|
|
|
|
for pos, op_desc in enumerate(grad_op_descs):
|
|
|
|
|
for var_name in op_desc.output_arg_names():
|
|
|
|
|
output_vars[var_name].append(pos)
|
|
|
|
|
for var_name, poses in output_vars.iteritems():
|
|
|
|
|
if len(poses) == 1:
|
|
|
|
|
continue
|
|
|
|
|
renamed_list = []
|
|
|
|
|
for pos in reversed(sorted(poses)):
|
|
|
|
|
new_name = var_name + "@RENAMED@" + len(renamed_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
|
|
|
|
|