|
|
|
@ -19,8 +19,20 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
|
|
|
|
|
:rtype: list[Variable]
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(loss, framework.Variable)
|
|
|
|
|
param_grad_map = loss.block.program.append_backward(loss, no_grad_set or
|
|
|
|
|
set())
|
|
|
|
|
|
|
|
|
|
if no_grad_set is None:
|
|
|
|
|
program = loss.block.program
|
|
|
|
|
assert isinstance(program, framework.Program)
|
|
|
|
|
no_grad_set = list()
|
|
|
|
|
for block in program.blocks:
|
|
|
|
|
assert isinstance(block, framework.Block)
|
|
|
|
|
for var in block.vars.itervalues():
|
|
|
|
|
assert isinstance(var, framework.Variable)
|
|
|
|
|
if var.stop_gradient:
|
|
|
|
|
no_grad_set.append(var.name)
|
|
|
|
|
no_grad_set = set(no_grad_set)
|
|
|
|
|
|
|
|
|
|
param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
|
|
|
|
|
if parameter_list is not None:
|
|
|
|
|
parameters = parameter_list
|
|
|
|
|
else:
|
|
|
|
|