|
|
@ -6,7 +6,8 @@ import pdb
|
|
|
|
__all__ = ['append_backward_ops']
|
|
|
|
__all__ = ['append_backward_ops']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
|
|
|
|
def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
|
|
|
|
|
|
|
|
end_idx=None):
|
|
|
|
if begin_idx is None:
|
|
|
|
if begin_idx is None:
|
|
|
|
begin_idx = 0
|
|
|
|
begin_idx = 0
|
|
|
|
if end_idx is None:
|
|
|
|
if end_idx is None:
|
|
|
@ -16,6 +17,21 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
|
|
|
|
op_desc_list[i].rename_output(old_name, new_name)
|
|
|
|
op_desc_list[i].rename_output(old_name, new_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_op_desc_(op_type, inputs, outputs, attrs):
    """Build a ``core.OpDesc`` from an operator type and plain mappings.

    Args:
        op_type: Operator type, forwarded to ``OpDesc.set_type``.
        inputs: Mapping of input parameter name to its arguments; every
            pair is forwarded to ``OpDesc.set_input``.
        outputs: Mapping of output parameter name to its arguments; every
            pair is forwarded to ``OpDesc.set_output``.
        attrs: Mapping of attribute name to value. A value that is a
            ``framework.Block`` is stored through ``set_block_attr`` using
            the block's ``desc``; any other value goes through ``set_attr``.

    Returns:
        The newly created and populated ``core.OpDesc``.
    """
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)

    # ``items()`` instead of ``iteritems()``: the latter only exists on
    # Python 2, while ``items()`` iterates the same pairs on Python 2 and 3.
    for para, args in inputs.items():
        op_desc.set_input(para, args)
    for para, args in outputs.items():
        op_desc.set_output(para, args)

    for name, val in attrs.items():
        if isinstance(val, framework.Block):
            # Block-valued attributes are registered via the block's desc.
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)

    return op_desc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def backward_impl(target,
|
|
|
|
def backward_impl(target,
|
|
|
|
block,
|
|
|
|
block,
|
|
|
|
target_block,
|
|
|
|
target_block,
|
|
|
@ -23,9 +39,9 @@ def backward_impl(target,
|
|
|
|
grad_info_map,
|
|
|
|
grad_info_map,
|
|
|
|
callback=None):
|
|
|
|
callback=None):
|
|
|
|
grad_op_descs = []
|
|
|
|
grad_op_descs = []
|
|
|
|
grad_to_var = {}
|
|
|
|
grad_to_var = dict()
|
|
|
|
program = block.program
|
|
|
|
program = block.program
|
|
|
|
for each_op in block.ops:
|
|
|
|
for each_op in reversed(block.ops):
|
|
|
|
grad_sub_block_list = []
|
|
|
|
grad_sub_block_list = []
|
|
|
|
if each_op.has_attr("sub_block"):
|
|
|
|
if each_op.has_attr("sub_block"):
|
|
|
|
sub_block_idx = each_op.block_attr("sub_block")
|
|
|
|
sub_block_idx = each_op.block_attr("sub_block")
|
|
|
@ -34,10 +50,10 @@ def backward_impl(target,
|
|
|
|
backward_impl(target, sub_block, grad_sub_block, no_grad_set,
|
|
|
|
backward_impl(target, sub_block, grad_sub_block, no_grad_set,
|
|
|
|
grad_info_map, callback)
|
|
|
|
grad_info_map, callback)
|
|
|
|
grad_sub_block_list.append(grad_sub_block)
|
|
|
|
grad_sub_block_list.append(grad_sub_block)
|
|
|
|
grad_op_desc = core.get_grad_op_desc(each_op.desc,
|
|
|
|
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
|
|
|
|
no_grad_set[block.idx],
|
|
|
|
each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
|
|
|
|
grad_to_var, grad_sub_block_list)
|
|
|
|
|
|
|
|
grad_op_descs.append(grad_op_desc)
|
|
|
|
grad_op_descs.append(grad_op_desc)
|
|
|
|
|
|
|
|
grad_to_var = dict(grad_to_var, **op_grad_to_var)
|
|
|
|
# grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
|
|
|
|
# grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
|
|
|
|
# flatten grad_op_descs
|
|
|
|
# flatten grad_op_descs
|
|
|
|
grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
|
|
|
|
grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
|
|
|
@ -48,11 +64,10 @@ def backward_impl(target,
|
|
|
|
for pos, op_desc in enumerate(grad_op_descs):
|
|
|
|
for pos, op_desc in enumerate(grad_op_descs):
|
|
|
|
for var_name in op_desc.input_arg_names():
|
|
|
|
for var_name in op_desc.input_arg_names():
|
|
|
|
if len(var_inputs[var_name]) > 1:
|
|
|
|
if len(var_inputs[var_name]) > 1:
|
|
|
|
pdb.set_trace()
|
|
|
|
pending_sum_ops.append((_create_op_desc_(
|
|
|
|
pending_sum_ops.append((core.OpDesc(
|
|
|
|
op_type="sum_op",
|
|
|
|
type="sum_op",
|
|
|
|
|
|
|
|
inputs=var_inputs[var_name],
|
|
|
|
inputs=var_inputs[var_name],
|
|
|
|
output=[var_name],
|
|
|
|
outputs=[var_name],
|
|
|
|
attrs={}), pos))
|
|
|
|
attrs={}), pos))
|
|
|
|
var_inputs[var_name] = [var_name]
|
|
|
|
var_inputs[var_name] = [var_name]
|
|
|
|
for var_name in op_desc.output_arg_names():
|
|
|
|
for var_name in op_desc.output_arg_names():
|
|
|
@ -66,8 +81,8 @@ def backward_impl(target,
|
|
|
|
var_rename_count[var_name] = var_rename_count[var_name] + 1
|
|
|
|
var_rename_count[var_name] = var_rename_count[var_name] + 1
|
|
|
|
# rename original var_name
|
|
|
|
# rename original var_name
|
|
|
|
var_inputs[var_name][0] = new_name
|
|
|
|
var_inputs[var_name][0] = new_name
|
|
|
|
rename_arg(grad_op_descs, var_name, new_name, 0, pos)
|
|
|
|
_rename_arg_(grad_op_descs, var_name, new_name, 0, pos)
|
|
|
|
rename_arg(pending_sum_ops, var_name, new_name)
|
|
|
|
_rename_arg_(pending_sum_ops, var_name, new_name)
|
|
|
|
|
|
|
|
|
|
|
|
new_name = var_name + "@RENAME@" + \
|
|
|
|
new_name = var_name + "@RENAME@" + \
|
|
|
|
str(var_rename_count[var_name])
|
|
|
|
str(var_rename_count[var_name])
|
|
|
@ -76,10 +91,11 @@ def backward_impl(target,
|
|
|
|
var_inputs[var_name].append(new_name)
|
|
|
|
var_inputs[var_name].append(new_name)
|
|
|
|
for var_name, inputs in var_inputs.iteritems():
|
|
|
|
for var_name, inputs in var_inputs.iteritems():
|
|
|
|
if len(inputs) > 1:
|
|
|
|
if len(inputs) > 1:
|
|
|
|
pdb.set_trace()
|
|
|
|
pending_sum_ops.append((_create_op_desc_(
|
|
|
|
pending_sum_ops.append((core.OpDesc("sum_op", {"X": inputs},
|
|
|
|
op_type="sum_op",
|
|
|
|
{"Out": var_name}, {}),
|
|
|
|
inputs={"X": inputs},
|
|
|
|
len(grad_op_descs)))
|
|
|
|
outputs={"Out": var_name},
|
|
|
|
|
|
|
|
attrs={}), len(grad_op_descs)))
|
|
|
|
# TODO: remove op in no grad set
|
|
|
|
# TODO: remove op in no grad set
|
|
|
|
|
|
|
|
|
|
|
|
# The append order guarantees that pending_sum_ops is sorted by each sum_op's insertion position.
|
|
|
|
# The append order guarantees that pending_sum_ops is sorted by each sum_op's insertion position.
|
|
|
@ -103,15 +119,22 @@ def backward_impl(target,
|
|
|
|
target_block.desc.var(grad_target_name)
|
|
|
|
target_block.desc.var(grad_target_name)
|
|
|
|
grad_op_descs.insert(
|
|
|
|
grad_op_descs.insert(
|
|
|
|
0,
|
|
|
|
0,
|
|
|
|
core.OpDesc(u"fill_constant", {}, {
|
|
|
|
_create_op_desc_(
|
|
|
|
u"Out": [unicode(grad_target_name, "ascii")]
|
|
|
|
op_type="fill_constant",
|
|
|
|
}, {u"shape": (1),
|
|
|
|
inputs={},
|
|
|
|
u"value": 1.0,
|
|
|
|
outputs={"Out": [grad_target_name]},
|
|
|
|
u"dtype": core.DataType.FP32}))
|
|
|
|
attrs={
|
|
|
|
|
|
|
|
"shape": [1],
|
|
|
|
|
|
|
|
"value": 1.0,
|
|
|
|
|
|
|
|
"dtype": core.DataType.FP32
|
|
|
|
|
|
|
|
}))
|
|
|
|
# insert backward operators to target_block
|
|
|
|
# insert backward operators to target_block
|
|
|
|
for op_desc in grad_op_descs:
|
|
|
|
for op_desc in grad_op_descs:
|
|
|
|
|
|
|
|
op_desc.infer_var_type(target_block.desc)
|
|
|
|
|
|
|
|
op_desc.infer_shape(target_block.desc)
|
|
|
|
target_block.desc.append_allocated_op(op_desc)
|
|
|
|
target_block.desc.append_allocated_op(op_desc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pdb.set_trace()
|
|
|
|
target_block.sync_with_cpp()
|
|
|
|
target_block.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -147,6 +170,7 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
|
|
|
|
|
|
|
|
|
|
|
|
grad_info_map = dict()
|
|
|
|
grad_info_map = dict()
|
|
|
|
root_block = loss.block.program.block(0)
|
|
|
|
root_block = loss.block.program.block(0)
|
|
|
|
|
|
|
|
pdb.set_trace()
|
|
|
|
backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
|
|
|
|
backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
|
|
|
|
pdb.set_trace()
|
|
|
|
pdb.set_trace()
|
|
|
|
if parameter_list is not None:
|
|
|
|
if parameter_list is not None:
|
|
|
@ -159,7 +183,7 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
|
|
|
|
if param not in grad_info_map:
|
|
|
|
if param not in grad_info_map:
|
|
|
|
raise ValueError("param %s is not in map" % param)
|
|
|
|
raise ValueError("param %s is not in map" % param)
|
|
|
|
grad_info = grad_info_map[param]
|
|
|
|
grad_info = grad_info_map[param]
|
|
|
|
grad_block = loss.block.program.block(grad_info[1])
|
|
|
|
grad_block = grad_info[1]
|
|
|
|
if not grad_block.has_var(grad_info[0]):
|
|
|
|
if not grad_block.has_var(grad_info[0]):
|
|
|
|
raise ValueError("grad block[{0}] did not have grad var {1}".format(
|
|
|
|
raise ValueError("grad block[{0}] did not have grad var {1}".format(
|
|
|
|
grad_info[1], grad_info[0]))
|
|
|
|
grad_info[1], grad_info[0]))
|
|
|
|