|
|
|
@ -32,12 +32,27 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
|
|
|
|
|
return op_desc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def backward_impl(target,
|
|
|
|
|
block,
|
|
|
|
|
target_block,
|
|
|
|
|
no_grad_set,
|
|
|
|
|
grad_info_map,
|
|
|
|
|
callback=None):
|
|
|
|
|
def _is_all_in_set_(cands, s):
|
|
|
|
|
for c in cands:
|
|
|
|
|
if not c in s:
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_grad_suffix_(name):
|
|
|
|
|
return name[:name.find(core.grad_var_suffix())]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_grad_suffix_(name):
|
|
|
|
|
return name + core.grad_var_suffix()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _backward_impl_(target,
|
|
|
|
|
block,
|
|
|
|
|
target_block,
|
|
|
|
|
no_grad_set,
|
|
|
|
|
grad_info_map,
|
|
|
|
|
callback=None):
|
|
|
|
|
grad_op_descs = []
|
|
|
|
|
grad_to_var = dict()
|
|
|
|
|
program = block.program
|
|
|
|
@ -47,8 +62,8 @@ def backward_impl(target,
|
|
|
|
|
sub_block_idx = each_op.block_attr("sub_block")
|
|
|
|
|
sub_block = program.block(sub_block_idx)
|
|
|
|
|
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
|
|
|
|
|
backward_impl(target, sub_block, grad_sub_block, no_grad_set,
|
|
|
|
|
grad_info_map, callback)
|
|
|
|
|
_backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
|
|
|
|
|
grad_info_map, callback)
|
|
|
|
|
grad_sub_block_list.append(grad_sub_block)
|
|
|
|
|
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
|
|
|
|
|
each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
|
|
|
|
@ -61,14 +76,14 @@ def backward_impl(target,
|
|
|
|
|
pending_sum_ops = []
|
|
|
|
|
var_rename_count = collections.defaultdict(int)
|
|
|
|
|
var_inputs = collections.defaultdict(list)
|
|
|
|
|
for pos, op_desc in enumerate(grad_op_descs):
|
|
|
|
|
for idx, op_desc in enumerate(grad_op_descs):
|
|
|
|
|
for var_name in op_desc.input_arg_names():
|
|
|
|
|
if len(var_inputs[var_name]) > 1:
|
|
|
|
|
pending_sum_ops.append((_create_op_desc_(
|
|
|
|
|
op_type="sum_op",
|
|
|
|
|
inputs=var_inputs[var_name],
|
|
|
|
|
outputs=[var_name],
|
|
|
|
|
attrs={}), pos))
|
|
|
|
|
attrs={}), idx))
|
|
|
|
|
var_inputs[var_name] = [var_name]
|
|
|
|
|
for var_name in op_desc.output_arg_names():
|
|
|
|
|
if len(var_inputs[var_name]) == 0:
|
|
|
|
@ -81,7 +96,7 @@ def backward_impl(target,
|
|
|
|
|
var_rename_count[var_name] = var_rename_count[var_name] + 1
|
|
|
|
|
# rename original var_name
|
|
|
|
|
var_inputs[var_name][0] = new_name
|
|
|
|
|
_rename_arg_(grad_op_descs, var_name, new_name, 0, pos)
|
|
|
|
|
_rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
|
|
|
|
|
_rename_arg_(pending_sum_ops, var_name, new_name)
|
|
|
|
|
|
|
|
|
|
new_name = var_name + "@RENAME@" + \
|
|
|
|
@ -96,18 +111,31 @@ def backward_impl(target,
|
|
|
|
|
inputs={"X": inputs},
|
|
|
|
|
outputs={"Out": var_name},
|
|
|
|
|
attrs={}), len(grad_op_descs)))
|
|
|
|
|
# TODO: remove op in no grad set
|
|
|
|
|
|
|
|
|
|
# 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的
|
|
|
|
|
for p in reversed(pending_sum_ops):
|
|
|
|
|
grad_op_descs.insert(p[1], p[0])
|
|
|
|
|
# Remove ops whose outputs are all in no_grad_set
|
|
|
|
|
grad_op_descs = filter(
|
|
|
|
|
lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
|
|
|
|
|
grad_op_descs)
|
|
|
|
|
# Insert fill_zeros_like_op
|
|
|
|
|
to_insert = []
|
|
|
|
|
for idx, op_desc in enumerate(grad_op_descs):
|
|
|
|
|
for arg in op_desc.input_arg_names():
|
|
|
|
|
if arg in no_grad_set[block.idx]:
|
|
|
|
|
to_insert.append((arg, idx))
|
|
|
|
|
for ele in reversed(to_insert):
|
|
|
|
|
arg = ele[0]
|
|
|
|
|
fill_zeros_like_op = _create_op_desc_(
|
|
|
|
|
"fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
|
|
|
|
|
{})
|
|
|
|
|
grad_op_descs.insert(ele[1], fill_zeros_like_op)
|
|
|
|
|
# create new gradient variables in the target block desc
|
|
|
|
|
for op_desc in grad_op_descs:
|
|
|
|
|
for grad_var_name in op_desc.output_arg_names():
|
|
|
|
|
grad_var_name = grad_var_name.encode("ascii")
|
|
|
|
|
if target_block.desc.has_var(
|
|
|
|
|
grad_var_name) or grad_var_name == core.get_empty_var_name(
|
|
|
|
|
):
|
|
|
|
|
grad_var_name) or grad_var_name == core.empty_var_name():
|
|
|
|
|
continue
|
|
|
|
|
target_block.desc.var(grad_var_name)
|
|
|
|
|
if not grad_to_var.has_key(grad_var_name):
|
|
|
|
@ -115,8 +143,8 @@ def backward_impl(target,
|
|
|
|
|
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
|
|
|
|
|
target_block)
|
|
|
|
|
if target_block.idx == 0:
|
|
|
|
|
grad_target_name = (target.name + "@GRAD")
|
|
|
|
|
target_block.desc.var(grad_target_name)
|
|
|
|
|
grad_target_name = _append_grad_suffix_(target.name)
|
|
|
|
|
target_block.desc.var(grad_target_name.encode("ascii"))
|
|
|
|
|
grad_op_descs.insert(
|
|
|
|
|
0,
|
|
|
|
|
_create_op_desc_(
|
|
|
|
@ -134,7 +162,6 @@ def backward_impl(target,
|
|
|
|
|
op_desc.infer_shape(target_block.desc)
|
|
|
|
|
target_block.desc.append_allocated_op(op_desc)
|
|
|
|
|
|
|
|
|
|
pdb.set_trace()
|
|
|
|
|
target_block.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -165,14 +192,14 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
|
|
|
|
|
for var in block.vars.itervalues():
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
if var.stop_gradient:
|
|
|
|
|
block_no_grad_set.add(var.name)
|
|
|
|
|
block_no_grad_set.add(_append_grad_suffix_(var.name))
|
|
|
|
|
no_grad_set[block.idx] = block_no_grad_set
|
|
|
|
|
|
|
|
|
|
grad_info_map = dict()
|
|
|
|
|
root_block = loss.block.program.block(0)
|
|
|
|
|
pdb.set_trace()
|
|
|
|
|
backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
|
|
|
|
|
pdb.set_trace()
|
|
|
|
|
|
|
|
|
|
_backward_impl_(loss, root_block, root_block, no_grad_set, grad_info_map)
|
|
|
|
|
|
|
|
|
|
if parameter_list is not None:
|
|
|
|
|
parameters = parameter_list
|
|
|
|
|
else:
|
|
|
|
|