|
|
|
|
@ -449,6 +449,17 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(loss, framework.Variable)
|
|
|
|
|
|
|
|
|
|
if loss.op is None:
|
|
|
|
|
# the loss is from a cloned program. Find loss op manually.
|
|
|
|
|
for op in reversed(loss.block.ops):
|
|
|
|
|
assert isinstance(op, framework.Operator)
|
|
|
|
|
if len(op.output_arg_names) == 1 and op.output_arg_names[
|
|
|
|
|
0] == loss.name:
|
|
|
|
|
loss.op = op
|
|
|
|
|
break
|
|
|
|
|
if loss.op is None:
|
|
|
|
|
raise ValueError("loss.op is None. Should not happend")
|
|
|
|
|
|
|
|
|
|
loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
|
|
|
|
|
int(core.op_proto_and_checker_maker.OpRole.Forward) |
|
|
|
|
|
int(core.op_proto_and_checker_maker.OpRole.Loss))
|
|
|
|
|
|