Extract apply_backward_pass to backward.py (#5026)
* Extract apply_backward_pass to backward.py Rename apply_backward_pass to append_backward_ops * Fix CI * Update design docrevert-4814-Add_sequence_project_op
parent
fd2eb55071
commit
dd0008d57f
@ -0,0 +1,45 @@
|
||||
from paddle.v2.framework import framework as framework
|
||||
|
||||
__all__ = ['append_backward_ops']
|
||||
|
||||
|
||||
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
|
||||
"""
|
||||
Create and add gradient Operators in BlockDesc to compute
|
||||
gradients of `loss` for parameters in parameter_list
|
||||
|
||||
:param loss: an variable generated by cost function.
|
||||
:type loss: Variable
|
||||
:param no_grad_set: variable that should not create gradient
|
||||
:type no_grad_set: set
|
||||
:param parameter_list: parameters that need to compute gradient and
|
||||
update to optimize the lost.
|
||||
:type: list
|
||||
:return: list of (parameters, gradients) pair.
|
||||
:rtype: list[Variable]
|
||||
"""
|
||||
assert isinstance(loss, framework.Variable)
|
||||
param_grad_map = loss.block.program.append_backward(loss, no_grad_set or
|
||||
set())
|
||||
if parameter_list is not None:
|
||||
parameters = parameter_list
|
||||
else:
|
||||
params = loss.block.program.global_block().all_parameters()
|
||||
parameters = [param.name for param in params]
|
||||
params_and_grads = []
|
||||
for param in parameters:
|
||||
if param not in param_grad_map:
|
||||
raise ValueError("param %s is not in map" % param)
|
||||
grad_info = param_grad_map[param]
|
||||
grad_block = loss.block.program.block(grad_info[1])
|
||||
if not grad_block.has_var(grad_info[0]):
|
||||
raise ValueError("grad block[{0}] did not have grad var {1}".format(
|
||||
grad_info[1], grad_info[0]))
|
||||
# Get the param var from the global block
|
||||
param_var = loss.block.program.global_block().var(param)
|
||||
grad_var = grad_block.var(grad_info[0])
|
||||
if loss.block.has_var(grad_info[0]):
|
||||
params_and_grads.append((param_var, grad_var))
|
||||
else:
|
||||
params_and_grads.append((param_var, None))
|
||||
return params_and_grads
|
Loading…
Reference in new issue