@@ -51,6 +51,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
         op_desc.set_input(para, args)
     for para, args in outputs.iteritems():
         op_desc.set_output(para, args)
+
+    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
+
+    if op_role_attr_name not in attrs:
+        attrs[
+            op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
     for name, val in attrs.iteritems():
         if isinstance(val, framework.Block):
             op_desc.set_block_attr(name, val.desc)
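Note on the hunk above: _create_op_desc_ now falls back to OpRole.Backward whenever the caller does not pass an op-role attribute explicitly, so every grad op built through this helper is tagged by default. A minimal stand-alone sketch of that default-attribute pattern (plain Python; the attribute name and the OpRole values below are illustrative stand-ins, not fluid's real constants):

    from enum import IntEnum

    OP_ROLE_ATTR_NAME = "op_role"      # stand-in for kOpRoleAttrName()

    class OpRole(IntEnum):             # stand-in values, not fluid's enum
        Forward = 0
        Backward = 1
        Loss = 256

    def create_op_desc(op_type, inputs, outputs, attrs):
        # Default every op created through this helper to the Backward role,
        # unless the caller already set a role explicitly.
        attrs.setdefault(OP_ROLE_ATTR_NAME, OpRole.Backward)
        return {"type": op_type, "inputs": inputs, "outputs": outputs,
                "attrs": attrs}

    desc = create_op_desc("mul_grad", {"X": ["x"]}, {"Out": ["x@GRAD"]}, {})
    assert desc["attrs"][OP_ROLE_ATTR_NAME] == OpRole.Backward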
@@ -141,7 +147,7 @@ def _addup_repetitive_outputs_(op_descs):
             else:
                 if len(renamed_vars[var_name]) == 1:
                     new_name = var_name + "@RENAME@" + \
-                        str(var_rename_count[var_name])
+                        str(var_rename_count[var_name])
                     var_rename_count[var_name] += 1
                     # rename original var_name
                     renamed_vars[var_name][0] = new_name
@@ -149,7 +155,7 @@ def _addup_repetitive_outputs_(op_descs):
                     _rename_arg_(pending_sum_ops, var_name, new_name)

                 new_name = var_name + "@RENAME@" + \
-                    str(var_rename_count[var_name])
+                    str(var_rename_count[var_name])
                 var_rename_count[var_name] += 1
                 op_desc.rename_output(var_name, new_name)
                 renamed_vars[var_name].append(new_name)
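The two hunks above appear to be whitespace-only reflows of the continuation lines, but the surrounding code is the interesting part: when several grad ops write the same variable, _addup_repetitive_outputs_ renames each write to <name>@RENAME@<n> and queues a sum op so the copies are added back together. A simplified sketch of that idea (plain dicts instead of OpDesc objects, and it ignores the input-side bookkeeping the real helper also does):

    import collections

    def addup_repetitive_outputs(op_descs):
        # Rename every write of a variable to <name>@RENAME@<n>; if only one
        # writer exists, restore the original name, otherwise add a sum op.
        rename_count = collections.defaultdict(int)
        copies = collections.defaultdict(list)
        for op in op_descs:
            outs = []
            for name in op["outputs"]:
                new_name = "%s@RENAME@%d" % (name, rename_count[name])
                rename_count[name] += 1
                copies[name].append(new_name)
                outs.append(new_name)
            op["outputs"] = outs
        sum_ops = []
        for name, parts in copies.items():
            if len(parts) == 1:
                for op in op_descs:
                    op["outputs"] = [name if o == parts[0] else o
                                     for o in op["outputs"]]
            else:
                sum_ops.append({"type": "sum", "inputs": parts,
                                "outputs": [name]})
        return op_descs + sum_ops

    ops = [{"type": "a_grad", "outputs": ["x@GRAD"]},
           {"type": "b_grad", "outputs": ["x@GRAD"]}]
    ops = addup_repetitive_outputs(ops)
    assert ops[-1]["type"] == "sum" and ops[-1]["outputs"] == ["x@GRAD"]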
@@ -335,9 +341,12 @@ def _append_backward_ops_(block,
                                             no_grad_dict[block.idx])

     # append op_desc in grad_op_descs to target_block
+    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
+    backward = core.op_proto_and_checker_maker.OpRole.Backward
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
+        new_op_desc.set_attr(op_role_attr_name, backward)
         grad_to_var["__current_op_desc__"] = new_op_desc
         if callbacks is not None:
             assert (isinstance(callbacks, list))
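Every grad op appended to the target block now also carries OpRole.Backward via set_attr, presumably so later passes can separate forward ops from backward ops by looking at a single attribute. A hedged sketch of reading that tag back (plain dict stand-ins, not fluid operators; the attribute name and role value are illustrative):

    OP_ROLE_ATTR_NAME = "op_role"    # stand-in for kOpRoleAttrName()
    BACKWARD = 1                     # stand-in for OpRole.Backward

    def split_forward_backward(op_descs):
        # Partition a block's ops by role, treating a missing attribute as
        # forward (e.g., ops created before this change landed).
        fwd, bwd = [], []
        for op in op_descs:
            role = op.get("attrs", {}).get(OP_ROLE_ATTR_NAME, 0)
            (bwd if role & BACKWARD else fwd).append(op)
        return fwd, bwd

    ops = [{"type": "mul", "attrs": {}},
           {"type": "mul_grad", "attrs": {OP_ROLE_ATTR_NAME: BACKWARD}}]
    fwd, bwd = split_forward_backward(ops)
    assert [o["type"] for o in fwd] == ["mul"]
    assert [o["type"] for o in bwd] == ["mul_grad"]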
@@ -439,6 +448,22 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
         (list[(Variable,Variable)]): list of (parameter, gradient) pair.
     """
     assert isinstance(loss, framework.Variable)
+
+    if loss.op is None:
+        # the loss is from a cloned program. Find loss op manually.
+        for op in reversed(loss.block.ops):
+            assert isinstance(op, framework.Operator)
+            if len(op.output_arg_names) == 1 and op.output_arg_names[
+                    0] == loss.name:
+                loss.op = op
+                break
+        if loss.op is None:
+            raise ValueError("loss.op is None. Should not happen")
+
+    loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
+                     int(core.op_proto_and_checker_maker.OpRole.Forward) |
+                     int(core.op_proto_and_checker_maker.OpRole.Loss))
+
     if callbacks is not None:
         isinstance(callbacks, list)
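When loss.op is not set (the loss came from a cloned program), the loss op is located by scanning the block backwards for the op whose single output is the loss variable, and it is then tagged with Forward | Loss. The bitwise OR means the role attribute acts as a small bit mask, so one op can carry several roles at once. A minimal sketch of that flag arithmetic (the numeric values are illustrative, not necessarily fluid's actual enum constants):

    # Illustrative bit-mask layout for op roles.
    FORWARD = 0x0000
    BACKWARD = 0x0001
    LOSS = 0x0100

    loss_op_role = FORWARD | LOSS        # role set on the forward loss op
    grad_seed_role = BACKWARD | LOSS     # role set on the fill_constant below

    assert loss_op_role & LOSS and not (loss_op_role & BACKWARD)
    assert grad_seed_role & LOSS and grad_seed_role & BACKWARD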
@@ -456,12 +481,16 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
     current_block_idx = program.current_block_idx
     grad_to_var = dict()

-    op_desc = _create_op_desc_("fill_constant", {}, {
-        "Out": [_append_grad_suffix_(loss.name)]
-    }, {"shape": [1],
-        "value": 1.0,
-        "dtype": loss.dtype,
-        "force_cpu": False})
+    op_desc = _create_op_desc_(
+        "fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
+            "shape": [1],
+            "value": 1.0,
+            "dtype": loss.dtype,
+            "force_cpu": False,
+            core.op_proto_and_checker_maker.kOpRoleAttrName():
+            int(core.op_proto_and_checker_maker.OpRole.Backward) |
+            int(core.op_proto_and_checker_maker.OpRole.Loss),
+        })
     root_block.desc.append_op().copy_from(op_desc)

     block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
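The rewritten fill_constant call seeds the backward pass: the gradient of the loss with respect to itself is 1, so a one-element tensor of value 1.0 (in the loss's dtype) is created and, with this change, tagged Backward | Loss in the same call. A small numpy illustration of why the seed is 1 (numpy is used here only for the arithmetic, it is not part of the diff):

    import numpy as np

    # d(loss)/d(loss) = 1: this is what the fill_constant op materializes
    # (shape [1], value 1.0, dtype matching the loss).
    loss_grad = np.full([1], 1.0, dtype=np.float32)

    # Chain rule example: if loss = mean(y) with y of size n, then
    # d(loss)/d(y_i) = loss_grad * (1 / n) for every element.
    n = 4
    y_grad = loss_grad * np.full([n], 1.0 / n, dtype=np.float32)
    assert np.allclose(y_grad, 0.25)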
@@ -505,6 +534,24 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
             params_and_grads.append((param_var, grad_var))
         else:
             params_and_grads.append((param_var, None))
+
+    op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
+    for p, g in params_and_grads:
+        if g is None:
+            continue
+        for op in reversed(program.global_block().ops):
+            assert isinstance(op, framework.Operator)
+            if g.name in op.output_arg_names:
+                g.op = op
+                break
+
+        if g.op is None:
+            raise ValueError("Unexpected branch")
+        attr_val = [p.name, g.name]
+        if g.op.has_attr(op_role_var_attr_name):
+            attr_val.extend(g.op.attr(op_role_var_attr_name))
+        g.op.set_attr(op_role_var_attr_name, attr_val)
+
     return params_and_grads
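The final hunk walks the (parameter, gradient) pairs, finds the global-block op that produces each gradient, and records the pair in a kOpRoleVar attribute, prepending new pairs to any existing ones, presumably so later passes can map a grad op back to the parameters it serves. A stand-alone sketch of that flat pair-list convention (plain dicts, not fluid operators; the attribute name is a stand-in):

    OP_ROLE_VAR_ATTR_NAME = "op_role_var"   # stand-in for kOpRoleVarAttrName()

    def tag_param_and_grad(op, param_name, grad_name):
        # The attribute is a flat list [p0, g0, p1, g1, ...]; the new pair is
        # prepended and existing pairs are preserved, as in the hunk above.
        attr_val = [param_name, grad_name]
        attr_val.extend(op["attrs"].get(OP_ROLE_VAR_ATTR_NAME, []))
        op["attrs"][OP_ROLE_VAR_ATTR_NAME] = attr_val

    op = {"type": "mul_grad", "attrs": {}}
    tag_param_and_grad(op, "fc_0.w_0", "fc_0.w_0@GRAD")
    tag_param_and_grad(op, "fc_0.b_0", "fc_0.b_0@GRAD")
    assert op["attrs"][OP_ROLE_VAR_ATTR_NAME] == [
        "fc_0.b_0", "fc_0.b_0@GRAD", "fc_0.w_0", "fc_0.w_0@GRAD"
    ]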