|
|
|
@ -5,14 +5,17 @@ import collections
|
|
|
|
|
__all__ = ['append_backward']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
|
|
|
|
|
end_idx=None):
|
|
|
|
|
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
|
|
|
|
|
"""
|
|
|
|
|
Traverse all ops in op_descs[begin_idx : end_idx],
|
|
|
|
|
if any op has inputs/outputs named "old_name", rename it as 'new_name'
|
|
|
|
|
"""
|
|
|
|
|
if begin_idx is None:
|
|
|
|
|
begin_idx = 0
|
|
|
|
|
if end_idx is None:
|
|
|
|
|
end_idx = len(op_desc_list)
|
|
|
|
|
end_idx = len(op_descs)
|
|
|
|
|
for i in range(begin_idx, end_idx):
|
|
|
|
|
op_desc = op_desc_list[i]
|
|
|
|
|
op_desc = op_descs[i]
|
|
|
|
|
if isinstance(op_desc, tuple):
|
|
|
|
|
op_desc = op_desc[0]
|
|
|
|
|
op_desc.rename_input(old_name, new_name)
|
|
|
|
@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_op_desc_(op_type, inputs, outputs, attrs):
|
|
|
|
|
"""
|
|
|
|
|
Create a C++ OpDesc object with specified inputs, outputs and attributes.
|
|
|
|
|
"""
|
|
|
|
|
op_desc = core.OpDesc()
|
|
|
|
|
op_desc.set_type(op_type)
|
|
|
|
|
for para, args in inputs.iteritems():
|
|
|
|
@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
|
|
|
|
|
return op_desc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _infer_var_data_type_(var_name, block):
|
|
|
|
|
grad_var = block.desc.find_var(var_name.encode("ascii"))
|
|
|
|
|
fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
|
|
|
|
|
def _infer_var_data_type_(grad_var_name, block):
|
|
|
|
|
"""
|
|
|
|
|
Infer the data type of given grad variable
|
|
|
|
|
"""
|
|
|
|
|
grad_var = block.desc.find_var(grad_var_name.encode("ascii"))
|
|
|
|
|
fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii"))
|
|
|
|
|
if block.desc.has_var_recursive(fwd_name):
|
|
|
|
|
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
|
|
|
|
|
grad_var.set_dtype(fwd_var.dtype())
|
|
|
|
@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _all_in_set_(cands, s):
|
|
|
|
|
"""
|
|
|
|
|
Test if all elements of 'cands' are in set 's'
|
|
|
|
|
"""
|
|
|
|
|
for c in cands:
|
|
|
|
|
if not c in s:
|
|
|
|
|
return False
|
|
|
|
@ -52,18 +64,29 @@ def _all_in_set_(cands, s):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_grad_suffix_(name):
    """
    Strip the grad suffix from the given variable name
    e.g. x@GRAD ==> x
         y@GRAD@RENAME@1 ==> y
    """
    # str.find() returns the index of the FIRST occurrence of the suffix,
    # so any trailing decorations (e.g. "@RENAME@1") are stripped along
    # with the grad suffix itself. -1 means no suffix: return name as-is.
    pos = name.find(core.grad_var_suffix())
    return name[:pos] if pos != -1 else name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_grad_suffix_(name):
    """
    Append grad suffix to the given variable name
    e.g. x ==> x@GRAD
    """
    # The suffix string (e.g. "@GRAD") is owned by the C++ core so that
    # Python and C++ always agree on the naming convention.
    suffix = core.grad_var_suffix()
    return "".join([name, suffix])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _addup_repetitive_outputs_(op_descs):
|
|
|
|
|
# In backward part, an variable my be the output of more than one ops.
|
|
|
|
|
# In this case, the variable should be the accumulation of all the outputs.
|
|
|
|
|
# We adopt adding `sum_op`s to implement the accumulate.
|
|
|
|
|
"""
|
|
|
|
|
In backward part, an variable may be the output of more than one ops.
|
|
|
|
|
In this case, the variable should be the accumulation of all the outputs.
|
|
|
|
|
`sum_op`s are added to implement the accumulate.
|
|
|
|
|
"""
|
|
|
|
|
pending_sum_ops = []
|
|
|
|
|
var_rename_count = collections.defaultdict(int)
|
|
|
|
|
renamed_vars = collections.defaultdict(list)
|
|
|
|
@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _remove_no_grad_branch_(op_descs, no_grad_set):
|
|
|
|
|
"""
|
|
|
|
|
Remove unnecessary grad ops
|
|
|
|
|
A grad op can be removed in two cases:
|
|
|
|
|
1. all outputs of the grad op are in 'no_grad_set'
|
|
|
|
|
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
|
|
|
|
|
"""
|
|
|
|
|
# Remove ops whose outputs are all in no_grad_dict
|
|
|
|
|
op_descs = filter(
|
|
|
|
|
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
|
|
|
|
@ -133,6 +162,20 @@ def _append_backward_ops_(target,
|
|
|
|
|
no_grad_dict,
|
|
|
|
|
grad_to_var,
|
|
|
|
|
callback=None):
|
|
|
|
|
"""
|
|
|
|
|
Create all grad ops, and insert them into given block
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
target(Variable): the target variable of forward pass
|
|
|
|
|
block(Block): the block where forward ops are
|
|
|
|
|
target_block(Block): the block which is going to hold new generated grad ops
|
|
|
|
|
no_grad_dict(dict):
|
|
|
|
|
key(int) block index
|
|
|
|
|
val(set) a set of varibale names. These varibales have no gradient
|
|
|
|
|
grad_to_var(dict)(output argument):
|
|
|
|
|
key(str): grad variable name
|
|
|
|
|
val(str): corresponding forward variable name
|
|
|
|
|
"""
|
|
|
|
|
grad_op_descs = []
|
|
|
|
|
program = block.program
|
|
|
|
|
for op in reversed(block.ops):
|
|
|
|
@ -170,6 +213,20 @@ def _append_backward_ops_(target,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
|
|
|
|
|
"""
|
|
|
|
|
Create new variables required by backward pass.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
block(Block): the block where new variables will be created
|
|
|
|
|
start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
|
|
|
|
|
grad_to_var(dict):
|
|
|
|
|
key(str): grad variable name
|
|
|
|
|
val(str): corresponding forward variable name
|
|
|
|
|
In most cases, this dict is generated by _append_backward_ops_()
|
|
|
|
|
grad_info_map(dict)(output argument):
|
|
|
|
|
key(str): forward variable name
|
|
|
|
|
val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
|
|
|
|
|
"""
|
|
|
|
|
for op_idx in range(start_op_idx, block.desc.op_size()):
|
|
|
|
|
op_desc = block.desc.op(op_idx)
|
|
|
|
|
if op_desc.has_attr("sub_block"):
|
|
|
|
|