@@ -147,7 +147,7 @@ def _addup_repetitive_outputs_(op_descs):
            else:
                if len(renamed_vars[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] += 1
                    # rename original var_name
                    renamed_vars[var_name][0] = new_name
@@ -155,7 +155,7 @@ def _addup_repetitive_outputs_(op_descs):
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] += 1
                op_desc.rename_output(var_name, new_name)
                renamed_vars[var_name].append(new_name)
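
Editor's note: the two hunks above are the heart of the rename-and-sum trick. When a variable is written by more than one gradient op, each occurrence is renamed to a unique "@RENAME@<n>" alias, and a sum op later accumulates the aliases back into the original name. A framework-free sketch of the same scheme follows; every name in it is hypothetical, and it simplifies one detail (the real code inserts the sum ops at the correct positions via pending_sum_ops rather than appending them at the end).

    from collections import defaultdict

    def addup_repetitive_outputs(ops):
        renamed_vars = defaultdict(list)      # original name -> aliases
        var_rename_count = defaultdict(int)
        first_writer = {}                     # name -> (op, output index)

        for op in ops:
            for i, name in enumerate(op["outputs"]):
                if not renamed_vars[name]:
                    # First writer keeps the original name for now.
                    renamed_vars[name].append(name)
                    first_writer[name] = (op, i)
                    continue
                if len(renamed_vars[name]) == 1:
                    # A second writer appeared: retroactively rename the first.
                    alias = "%s@RENAME@%d" % (name, var_rename_count[name])
                    var_rename_count[name] += 1
                    renamed_vars[name][0] = alias
                    w_op, w_i = first_writer[name]
                    w_op["outputs"][w_i] = alias
                alias = "%s@RENAME@%d" % (name, var_rename_count[name])
                var_rename_count[name] += 1
                op["outputs"][i] = alias
                renamed_vars[name].append(alias)

        # Accumulate all aliases back into the original variable name.
        sum_ops = [{"type": "sum", "inputs": aliases, "outputs": [name]}
                   for name, aliases in renamed_vars.items() if len(aliases) > 1]
        return ops + sum_ops

    ops = [{"type": "mul_grad", "outputs": ["x@GRAD"]},
           {"type": "add_grad", "outputs": ["x@GRAD"]}]
    print(addup_repetitive_outputs(ops))
    # both writers now emit x@GRAD@RENAME@0 / x@GRAD@RENAME@1,
    # followed by a sum op that restores x@GRAD
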
@@ -434,18 +434,65 @@ def _get_stop_gradients_(program):
def append_backward(loss, parameter_list=None, no_grad_set=None,
                    callbacks=None):
    """
    Append backward part to main_program.

    A complete neural network training is made up of forward and backward
    propagation. However, when we configure a network, we only need to
    specify its forward part. The backward part is generated automatically
    according to the forward part by this function.

    In most cases, users do not need to invoke this function manually. It
    will be automatically invoked by the optimizer's `minimize` function.

    Args:
        loss(Variable): The loss variable of the network.
        parameter_list(list[string]|None): Names of parameters that need
                                           to be updated by optimizers.
                                           If it is None, all parameters
                                           will be updated.
                                           Default: None
        no_grad_set(set|None): Variables in Block 0 whose gradients
                               should be ignored. All variables with
                               `stop_gradient=True` from all blocks will
                               be automatically added into this set.
                               Default: None
        callbacks(list[callable object]|None): The callbacks are used for
                                               doing some custom jobs during
                                               backward part building. All
                                               callable objects in it will
                                               be invoked once each time a
                                               new gradient operator is added
                                               into the program. The callable
                                               object must have two input
                                               parameters: 'block' and 'context'.
                                               The 'block' is the block which
                                               the new gradient operator will
                                               be added to. The 'context' is a
                                               map, whose keys are gradient
                                               variable names and values are
                                               corresponding original variables.
                                               In addition to this, the 'context'
                                               has another special key-value pair:
                                               the key is the string
                                               '__current_op_desc__' and the
                                               value is the op_desc of the
                                               gradient operator that has just
                                               triggered the callable object.

    Returns:
        list[(Variable,Variable)]: Pairs of parameters and their
            corresponding gradients. The key is the parameter and the
            value is the gradient variable.

    Raises:
        AssertionError: If `loss` is not an instance of Variable.

    Examples:
        .. code-block:: python

            # network configuration code
            # ...
            avg_loss = fluid.layers.mean(loss)
            param_grad_list = fluid.backward.append_backward(loss=avg_loss)
    """
    assert isinstance(loss, framework.Variable)
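
Editor's note: the callbacks contract above is easiest to grasp from a concrete callback. A minimal sketch follows; `log_grad_op` and its printing logic are illustrative, only the (block, context) signature and the '__current_op_desc__' key come from the docstring, and the `op_desc.type()` / `block.idx` accessors are assumed from the fluid framework API.

    # Invoked once per newly added gradient operator.
    def log_grad_op(block, context):
        op_desc = context["__current_op_desc__"]
        print("gradient op '%s' added to block %d"
              % (op_desc.type(), block.idx))

    param_grads = fluid.backward.append_backward(loss=avg_loss,
                                                 callbacks=[log_grad_op])
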
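
And since the docstring states that the optimizer's `minimize` invokes this function internally, the two routes below should be equivalent ways of building the backward part; this is a sketch reusing `avg_loss` from the docstring example, with Route 2 preferred only when the (parameter, gradient) pairs are needed directly.

    import paddle.fluid as fluid

    # Route 1: implicit -- minimize() calls append_backward() internally
    # and then appends the optimization ops.
    sgd = fluid.optimizer.SGD(learning_rate=0.01)
    sgd.minimize(avg_loss)

    # Route 2: explicit -- exposes the (parameter, gradient) pairs.
    param_grad_list = fluid.backward.append_backward(loss=avg_loss)
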