|
|
|
@ -269,7 +269,7 @@ def _append_backward_ops_(block,
|
|
|
|
|
target_block,
|
|
|
|
|
no_grad_dict,
|
|
|
|
|
grad_to_var,
|
|
|
|
|
callback=None):
|
|
|
|
|
callbacks=None):
|
|
|
|
|
"""
|
|
|
|
|
Create all grad ops, and insert them into given block
|
|
|
|
|
|
|
|
|
@ -285,14 +285,13 @@ def _append_backward_ops_(block,
|
|
|
|
|
val(str): corresponding forward variable name
|
|
|
|
|
callback(callable object): a callable object used to decorate new generated grad ops
|
|
|
|
|
"""
|
|
|
|
|
if callback is None:
|
|
|
|
|
|
|
|
|
|
def empty_callback(block, context):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
callback = empty_callback
|
|
|
|
|
elif not hasattr(callback, '__call__'):
|
|
|
|
|
raise ValueError("'callback' must be a callable object.")
|
|
|
|
|
if callbacks is None:
|
|
|
|
|
callbacks = []
|
|
|
|
|
else:
|
|
|
|
|
assert (isinstance(callbacks, list))
|
|
|
|
|
for cb in callbacks:
|
|
|
|
|
if not hasattr(cb, '__call__'):
|
|
|
|
|
raise ValueError("'callback' must be a callable object.")
|
|
|
|
|
|
|
|
|
|
# grad_op_descs holds created grad_op, and will be appended to target_block
|
|
|
|
|
grad_op_descs = []
|
|
|
|
@ -303,9 +302,12 @@ def _append_backward_ops_(block,
|
|
|
|
|
if op.has_attr("sub_block"):
|
|
|
|
|
sub_block = program.block(op.block_attr("sub_block"))
|
|
|
|
|
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
|
|
|
|
|
if callbacks is None:
|
|
|
|
|
callbacks = [_callback_lookup_(op)]
|
|
|
|
|
else:
|
|
|
|
|
callbacks.append(_callback_lookup_(op))
|
|
|
|
|
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
|
|
|
|
|
no_grad_dict, grad_to_var,
|
|
|
|
|
_callback_lookup_(op))
|
|
|
|
|
no_grad_dict, grad_to_var, callbacks)
|
|
|
|
|
grad_sub_block_list.append(grad_sub_block.desc)
|
|
|
|
|
|
|
|
|
|
# Getting op's corresponding grad_op
|
|
|
|
@ -325,7 +327,8 @@ def _append_backward_ops_(block,
|
|
|
|
|
new_op_desc = target_block.desc.append_op()
|
|
|
|
|
new_op_desc.copy_from(op_desc)
|
|
|
|
|
grad_to_var["__current_op_desc__"] = new_op_desc
|
|
|
|
|
callback(block=target_block, context=grad_to_var)
|
|
|
|
|
for cb in callbacks:
|
|
|
|
|
cb(block=target_block, context=grad_to_var)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
|
|
|
|
|