shanyi15-patch-3
yuyang18 8 years ago
parent 3b04f0099c
commit 23e19e2e42

@ -449,6 +449,17 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
"""
assert isinstance(loss, framework.Variable)
if loss.op is None:
# the loss is from a cloned program. Find loss op manually.
for op in reversed(loss.block.ops):
assert isinstance(op, framework.Operator)
if len(op.output_arg_names) == 1 and op.output_arg_names[
0] == loss.name:
loss.op = op
break
if loss.op is None:
raise ValueError("loss.op is None. Should not happend")
loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
int(core.op_proto_and_checker_maker.OpRole.Forward) |
int(core.op_proto_and_checker_maker.OpRole.Loss))

Loading…
Cancel
Save