@@ -3,7 +3,7 @@ from . import core
import collections
import pdb

__all__ = ['append_backward_ops']
__all__ = ['append_backward']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
@@ -57,12 +57,11 @@ def _append_grad_suffix_(name):
    return name + core.grad_var_suffix()


def _backward_impl_(target,
                    block,
                    target_block,
                    no_grad_set,
                    grad_info_map,
                    callback=None):
def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_set,
                          callback=None):
    grad_op_descs = []
    grad_to_var = dict()
    program = block.program
@@ -71,11 +70,10 @@ def _backward_impl_(target,
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            original_block_idx = program.current_block_idx
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            program.current_block_idx = original_block_idx
            _backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
                            grad_info_map, callback)
            sub_grad_to_var = _append_backward_ops_(
                target, sub_block, grad_sub_block, no_grad_set, callback)
            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
            grad_sub_block_list.append(grad_sub_block.desc)
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
@@ -143,20 +141,7 @@ def _backward_impl_(target,
            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
            {})
        grad_op_descs.insert(ele[1], fill_zeros_like_op)
    # create new gradient variables in the target block desc
    new_vars = set()
    for op_desc in grad_op_descs:
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if target_block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            target_block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if not grad_to_var.has_key(grad_var_name):
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                         target_block)

    if target_block.idx == 0:
        grad_target_name = _append_grad_suffix_(target.name)
        target_block.desc.var(grad_target_name.encode("ascii"))
@@ -171,20 +156,40 @@ def _backward_impl_(target,
                    "value": 1.0,
                    "dtype": core.DataType.FP32
                }))
    # insert backward operators to target_block
    for op_desc in grad_op_descs:
        op_desc.infer_var_type(target_block.desc)
        op_desc.infer_shape(target_block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, target_block)
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)

    target_block.sync_with_cpp()
    return grad_to_var


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if not grad_to_var.has_key(grad_var_name):
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
def append_backward(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list
@@ -201,9 +206,9 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
    if no_grad_set is None:
        no_grad_set = dict()
        program = loss.block.program
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
@@ -215,14 +220,20 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
            no_grad_set[block.idx] = block_no_grad_set

    grad_info_map = dict()
    root_block = loss.block.program.block(0)
    root_block = program.block(0)

    _backward_impl_(loss, root_block, root_block, no_grad_set, grad_info_map)
    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
                                        no_grad_set)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = loss.block.program.global_block().all_parameters()
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
@@ -234,7 +245,7 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = loss.block.program.global_block().var(param)
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
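
For context on how the renamed entry point is meant to be called, below is a minimal usage sketch, not part of this diff. The network-building helpers (layers.data, layers.fc, layers.square_error_cost, layers.mean) and the import path are assumptions based on the fluid Python API of this period; only append_backward(loss) returning (parameter, gradient) pairs is what the change above provides.

# Hypothetical usage sketch (assumed fluid-era helper names, not from this diff).
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_pred = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_pred, label=y)
avg_cost = fluid.layers.mean(x=cost)

# The renamed API: appends gradient ops to the program's blocks and returns
# a list of (parameter Variable, gradient Variable) tuples.
params_and_grads = fluid.backward.append_backward(loss=avg_cost)
for param, grad in params_and_grads:
    print(param.name, grad.name)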