|
|
|
@ -53,7 +53,8 @@ def _is_all_in_set_(cands, s):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_grad_suffix_(name):
|
|
|
|
|
return name[:name.find(core.grad_var_suffix())]
|
|
|
|
|
pos = name.find(core.grad_var_suffix())
|
|
|
|
|
return name[:pos] if pos != -1 else name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_grad_suffix_(name):
|
|
|
|
@ -139,7 +140,7 @@ def _append_backward_ops_(target,
|
|
|
|
|
to_insert = []
|
|
|
|
|
for idx, op_desc in enumerate(grad_op_descs):
|
|
|
|
|
for arg in op_desc.input_arg_names():
|
|
|
|
|
if arg in no_grad_set[block.idx]:
|
|
|
|
|
if core.grad_var_suffix() in arg and arg in no_grad_set[block.idx]:
|
|
|
|
|
to_insert.append((arg, idx))
|
|
|
|
|
for ele in reversed(to_insert):
|
|
|
|
|
arg = ele[0]
|
|
|
|
|