You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/python/paddle/v2/fluid/backward.py

77 lines
3.1 KiB

from paddle.v2.fluid import framework as framework
from . import core
# Names exported by ``from <this module> import *``; backward_impl is
# intentionally left out of the list.
__all__ = ["append_backward_ops"]
def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
    """Collect gradient op descriptions for every op in ``block``.

    For each op in ``block.ops`` (visited in order), this asks
    ``core.get_grad_op_desc`` for the op's gradient description. An op that
    carries a ``sub_block`` attribute is handled recursively first:
    a new block is created in the program with the forward sub-block as its
    parent, ``backward_impl`` is run on the forward sub-block targeting that
    new block, and the new block is passed to ``core.get_grad_op_desc``.

    :param block: forward block whose ops are visited.
    :param target_block: block intended to receive the gradient ops.
        NOTE(review): this function does not read it yet; results are
        returned instead of being appended here.
    :param no_grad_set: indexed by block index (``no_grad_set[block.idx]``);
        presumably a per-block collection of variable names that need no
        gradient -- confirm against callers.
    :param grad_to_var: passed unchanged to ``core.get_grad_op_desc`` and to
        recursive calls.
    :param callback: only forwarded to recursive calls; not otherwise used.
    :return: list with one ``core.get_grad_op_desc`` result per op in
        ``block.ops``, in the same order as the ops (empty if there are no
        ops).
    """
    program = block.program
    grad_op_descs = []
    # TODO(review): backpropagation normally emits gradient ops in reverse
    # forward order; this visits ops in forward order -- confirm intent.
    for each_op in block.ops:
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            backward_impl(sub_block, grad_sub_block, no_grad_set, grad_to_var,
                          callback)
            # NOTE(review): the op is passed as ``each_op.desc`` below, but
            # the sub-block is passed as the Python block object itself --
            # confirm which form core.get_grad_op_desc expects.
            grad_sub_block_list.append(grad_sub_block)
        grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                             no_grad_set[block.idx],
                                             grad_to_var, grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
    # The list used to be built and then discarded (implicit ``None``);
    # return it so callers can actually consume the generated descriptions.
    return grad_op_descs
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list.

    :param loss: a variable generated by the cost function.
    :type loss: Variable
    :param parameter_list: names of the parameters whose gradients are
        needed to optimize the loss. Defaults to the names of all
        parameters in the program's global block.
    :type parameter_list: list
    :param no_grad_set: names of variables that should not get a gradient.
        Defaults to the names of every variable, in every block of the
        program, whose ``stop_gradient`` flag is set.
    :type no_grad_set: set
    :return: list of ``(param_var, grad_var)`` pairs, where ``param_var``
        comes from the global block and ``grad_var`` is None when the
        gradient variable is not present in ``loss.block``.
    :rtype: list[tuple[Variable, Variable]]
    :raises ValueError: if a parameter has no entry in the map returned by
        ``program.append_backward``, or if the block recorded for its
        gradient does not contain the gradient variable.
    """
    assert isinstance(loss, framework.Variable)
    program = loss.block.program

    if no_grad_set is None:
        assert isinstance(program, framework.Program)
        no_grad_set = set()
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            # .values() works on both Python 2 and 3 (.itervalues() is
            # Python 2 only).
            for var in block.vars.values():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    no_grad_set.add(var.name)

    param_grad_map = program.append_backward(loss, no_grad_set)
    global_block = program.global_block()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        parameters = [param.name for param in global_block.all_parameters()]

    params_and_grads = []
    for param in parameters:
        if param not in param_grad_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = param_grad_map[param]
        # grad_info[0] is the gradient variable name, grad_info[1] the index
        # of the block that is expected to hold it.
        grad_name = grad_info[0]
        grad_block_idx = grad_info[1]
        grad_block = program.block(grad_block_idx)
        if not grad_block.has_var(grad_name):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_block_idx, grad_name))
        # Get the param var from the global block
        param_var = global_block.var(param)
        grad_var = grad_block.var(grad_name)
        # NOTE(review): a gradient is reported only if it also lives in
        # loss.block; one that exists only in another block yields None.
        # Confirm this is intended.
        if loss.block.has_var(grad_name):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads