|
|
|
@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
|
|
|
|
|
"""
|
|
|
|
|
Test if all elements of 'cands' are in set 's'
|
|
|
|
|
"""
|
|
|
|
|
if len(cands) == 0:
|
|
|
|
|
return False
|
|
|
|
|
for c in cands:
|
|
|
|
|
if not c in s:
|
|
|
|
|
return False
|
|
|
|
@ -138,10 +140,20 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
|
|
|
|
|
1. all outputs of the grad op are in 'no_grad_set'
|
|
|
|
|
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _op_can_be_removed_(op_desc, no_grad_set):
|
|
|
|
|
if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
|
|
|
|
|
return True
|
|
|
|
|
if _all_in_set_(
|
|
|
|
|
filter(lambda name: name.find(core.grad_var_suffix()) != -1,
|
|
|
|
|
op_desc.input_arg_names()), no_grad_set):
|
|
|
|
|
no_grad_set.union(op_desc.output_arg_names())
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# Remove ops whose outputs are all in no_grad_dict
|
|
|
|
|
op_descs = filter(
|
|
|
|
|
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
|
|
|
|
|
op_descs)
|
|
|
|
|
lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
|
|
|
|
|
# Insert fill_zeros_like_op
|
|
|
|
|
to_insert = []
|
|
|
|
|
for idx, op_desc in enumerate(op_descs):
|
|
|
|
|