|
|
@ -190,8 +190,15 @@ def _append_backward_ops_(target,
|
|
|
|
val(str): corresponding forward variable name
|
|
|
|
val(str): corresponding forward variable name
|
|
|
|
callback(callable object): a callable object used to decorate new generated grad ops
|
|
|
|
callback(callable object): a callable object used to decorate new generated grad ops
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if callback is not None and not hasattr(callback, '__call__'):
|
|
|
|
if callback is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def empty_callback(block):
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
callback = empty_callback
|
|
|
|
|
|
|
|
elif not hasattr(callback, '__call__'):
|
|
|
|
raise ValueError("'callback' must be a callable object.")
|
|
|
|
raise ValueError("'callback' must be a callable object.")
|
|
|
|
|
|
|
|
|
|
|
|
# grad_op_descs holds created grad_op, and will be appended to target_block
|
|
|
|
# grad_op_descs holds created grad_op, and will be appended to target_block
|
|
|
|
grad_op_descs = []
|
|
|
|
grad_op_descs = []
|
|
|
|
program = block.program
|
|
|
|
program = block.program
|
|
|
@ -208,8 +215,6 @@ def _append_backward_ops_(target,
|
|
|
|
# Getting op's corresponding grad_op
|
|
|
|
# Getting op's corresponding grad_op
|
|
|
|
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
|
|
|
|
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
|
|
|
|
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
|
|
|
|
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
|
|
|
|
if callback is not None:
|
|
|
|
|
|
|
|
grad_op_desc = callback(grad_op_desc)
|
|
|
|
|
|
|
|
grad_op_descs.extend(grad_op_desc)
|
|
|
|
grad_op_descs.extend(grad_op_desc)
|
|
|
|
grad_to_var.update(op_grad_to_var)
|
|
|
|
grad_to_var.update(op_grad_to_var)
|
|
|
|
|
|
|
|
|
|
|
@ -230,6 +235,7 @@ def _append_backward_ops_(target,
|
|
|
|
for op_desc in grad_op_descs:
|
|
|
|
for op_desc in grad_op_descs:
|
|
|
|
new_op_desc = target_block.desc.append_op()
|
|
|
|
new_op_desc = target_block.desc.append_op()
|
|
|
|
new_op_desc.copy_from(op_desc)
|
|
|
|
new_op_desc.copy_from(op_desc)
|
|
|
|
|
|
|
|
callback(block=target_block, context=grad_to_var)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
|
|
|
|
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
|
|
|
|