|
|
|
@ -71,7 +71,9 @@ def _backward_impl_(target,
|
|
|
|
|
if each_op.has_attr("sub_block"):
|
|
|
|
|
sub_block_idx = each_op.block_attr("sub_block")
|
|
|
|
|
sub_block = program.block(sub_block_idx)
|
|
|
|
|
original_block_idx = program.current_block_idx
|
|
|
|
|
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
|
|
|
|
|
program.current_block_idx = original_block_idx
|
|
|
|
|
_backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
|
|
|
|
|
grad_info_map, callback)
|
|
|
|
|
grad_sub_block_list.append(grad_sub_block.desc)
|
|
|
|
@ -120,9 +122,9 @@ def _backward_impl_(target,
|
|
|
|
|
pending_sum_ops.append((_create_op_desc_(
|
|
|
|
|
op_type="sum",
|
|
|
|
|
inputs={"X": inputs},
|
|
|
|
|
outputs={"Out": var_name},
|
|
|
|
|
outputs={"Out": [var_name]},
|
|
|
|
|
attrs={}), len(grad_op_descs)))
|
|
|
|
|
# Entries are appended in order, so pending_sum_ops is already sorted by each sum op's insert position
|
|
|
|
|
# sum_op descs are sorted according to their insert position
|
|
|
|
|
for p in reversed(pending_sum_ops):
|
|
|
|
|
grad_op_descs.insert(p[1], p[0])
|
|
|
|
|
# Remove ops whose outputs are all in no_grad_set
|
|
|
|
|