@@ -16,7 +16,8 @@ from paddle.fluid import framework as framework
 from . import core
 import collections
 import copy
-import unique_name
+import six
+from . import unique_name
 
 __all__ = ['append_backward']
 
@@ -44,17 +45,25 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
     """
     op_desc = core.OpDesc()
     op_desc.set_type(op_type)
-    for para, args in inputs.iteritems():
-        op_desc.set_input(para, args)
-    for para, args in outputs.iteritems():
-        op_desc.set_output(para, args)
+    for para, args in list(inputs.items()):
+        op_desc.set_input(
+            para,
+            list(
+                map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
+                    args)))
+    for para, args in list(outputs.items()):
+        op_desc.set_output(
+            para,
+            list(
+                map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
+                    args)))
 
     op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
 
     if op_role_attr_name not in attrs:
         attrs[
             op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
-    for name, val in attrs.iteritems():
+    for name, val in list(attrs.items()):
         if isinstance(val, framework.Block):
             op_desc.set_block_attr(name, val.desc)
         else:
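Note: this hunk shows the two compatibility idioms used throughout the patch: `dict.iteritems()` exists only on Python 2, so it becomes `list(d.items())`, which works on both versions; and argument names that may come back from the C++ core as `bytes` are decoded to text before being handed to `OpDesc`. A minimal standalone sketch of the same idioms (the `normalize_args` helper and the sample dict are illustrative only, not part of the patch):

    import six

    def normalize_args(args):
        # Decode bytes entries to text; leave text entries untouched.
        return [a.decode() if isinstance(a, six.binary_type) else a for a in args]

    inputs = {"X": [b"x0", u"x1"]}
    for para, args in list(inputs.items()):   # same on Python 2 and 3
        print(para, normalize_args(args))     # -> X ['x0', 'x1']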
@@ -105,7 +114,9 @@ def _strip_grad_suffix_(name):
     e.g. x@GRAD ==> x
     y@GRAD@RENAME@1 ==> y
     """
-    pos = name.find(core.grad_var_suffix())
+    if isinstance(name, six.text_type):
+        name = name.encode()
+    pos = name.find(six.b(core.grad_var_suffix()))
     return name[:pos] if pos != -1 else name
 
 
@@ -114,7 +125,9 @@ def _append_grad_suffix_(name):
     Append grad suffix to the given variable name
     e.g. x ==> x@GRAD
     """
-    return name + core.grad_var_suffix()
+    if isinstance(name, six.text_type):
+        name = name.encode()
+    return name + six.b(core.grad_var_suffix())
 
 
 def _addup_repetitive_outputs_(op_descs):
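Note: both suffix helpers now normalize variable names to `bytes` before working on them: a text name is encoded first, and `core.grad_var_suffix()` (a native `str`) is converted with `six.b()` so the concatenation and `find()` operate on one consistent type under Python 2 and 3. A rough illustration, assuming the suffix is "@GRAD" as in the docstrings above (the `GRAD_SUFFIX` constant is a stand-in, not part of the patch):

    import six

    GRAD_SUFFIX = "@GRAD"  # stand-in for core.grad_var_suffix()

    def append_grad_suffix(name):
        # Mirror of the patched helper: always work on bytes.
        if isinstance(name, six.text_type):
            name = name.encode()
        return name + six.b(GRAD_SUFFIX)

    print(append_grad_suffix(u"x"))   # -> b'x@GRAD'
    print(append_grad_suffix(b"y"))   # -> b'y@GRAD'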
@@ -174,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
                     op_desc.set_output(param_name, arg_names)
                     renamed_vars[var_name].append(new_name)
 
-    for var_name, inputs in renamed_vars.iteritems():
+    for var_name, inputs in list(renamed_vars.items()):
         if len(inputs) > 1:
             pending_sum_ops.append(
                 (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
@@ -198,16 +211,19 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
         out_arg_names = op_desc.output_arg_names()
         if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
             return True
-        if _all_in_set_(
-                filter(lambda name: name.find(core.grad_var_suffix()) != -1,
-                       op_desc.input_arg_names()), no_grad_set):
+        if _all_in_set_([
+                name for name in op_desc.input_arg_names()
+                if name.find(core.grad_var_suffix()) != -1
+        ], no_grad_set):
             no_grad_set.update(out_arg_names)
             return True
         return False
 
     # Remove ops whose outputs are all in no_grad_dict
-    op_descs = filter(
-        lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
+    op_descs = [
+        op_desc for op_desc in op_descs
+        if not _op_can_be_removed_(op_desc, no_grad_set)
+    ]
     # Insert fill_zeros_like_op
     to_insert = []
     for idx, op_desc in enumerate(op_descs):
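Note: under Python 3, `filter()` and `map()` return lazy iterators rather than lists, so any result that is later indexed, iterated more than once, or passed where a list is expected has to be materialized; that is why the patch rewrites these calls as list comprehensions (or wraps them in `list(...)`). A small sketch of the difference, using the "@GRAD" naming from this file (the sample list is illustrative only):

    names = ["x@GRAD", "y", "z@GRAD"]

    # Python 2: filter() returns a list.  Python 3: a one-shot iterator.
    lazy = filter(lambda n: n.endswith("@GRAD"), names)

    # Portable form used by the patch: build a real list explicitly.
    grads = [n for n in names if n.endswith("@GRAD")]
    print(len(grads), grads)   # -> 2 ['x@GRAD', 'z@GRAD']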
@@ -217,12 +233,12 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
                     "X": [_strip_grad_suffix_(arg)]
                 }, {"Out": [arg]}, {}), idx))
 
-    map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
+    list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
 
     return op_descs
 
 
-import proto.framework_pb2 as framework_pb2
+from .proto import framework_pb2
 
 
 def serialize_op_decs(op_desc):
@@ -244,8 +260,10 @@ def _callback_lookup_(op):
     if op.type == 'parallel_do' and op.attr('use_nccl'):
         all_vars = op.block.vars
         param_names = set(op.input('parameters'))
-        param_names = filter(lambda name: all_vars[name].stop_gradient is False,
-                             param_names)
+        param_names = [
+            name for name in param_names
+            if all_vars[name].stop_gradient is False
+        ]
         param_grad_names = [n + "@GRAD" for n in param_names]
 
         class ParallelDoCallBack(object):
@@ -399,7 +417,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
                 continue
             block.desc.var(grad_var_name)
             new_vars.add(grad_var_name)
-            if not grad_to_var.has_key(grad_var_name):
+            if grad_var_name not in grad_to_var:
                 continue
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
         # infer_shape and infer_type
@@ -427,7 +445,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
             op_desc.rename_output(name, new_name)
             var_map[name] = new_name
 
-    for g, ng in var_map.iteritems():
+    for g, ng in list(var_map.items()):
         if g in grad_to_var:
             grad_to_var[ng] = grad_to_var[g]
             grad_to_var.pop(g)
@@ -439,7 +457,7 @@ def _get_stop_gradients_(program):
     for block in program.blocks:
         assert isinstance(block, framework.Block)
         block_no_grad_set = set()
-        for var in block.vars.itervalues():
+        for var in list(block.vars.values()):
             assert isinstance(var, framework.Variable)
             if var.stop_gradient:
                 block_no_grad_set.add(_append_grad_suffix_(var.name))
@@ -535,7 +553,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
         no_grad_set = set()
     no_grad_set = copy.copy(no_grad_set)
     no_grad_dict = _get_stop_gradients_(program)
-    no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))
+    no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set)))
 
     grad_info_map = dict()
     root_block = program.block(0)
@@ -558,7 +576,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
 
     block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
     op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set)
-    no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
+    no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
 
     _append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
                           grad_to_var, callbacks)
@@ -699,7 +717,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
         no_grad_set = set()
     no_grad_set = copy.copy(no_grad_set)
     no_grad_dict = _get_stop_gradients_(prog)
-    no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))
+    no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set)))
 
     fwd_op_num = block.desc.op_size()
 
@@ -733,7 +751,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
 
     block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
     op_path = _find_op_path_(block, targets, inputs, block_no_grad_set)
-    no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
+    no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
     grad_to_var = dict()
     grad_info_map = dict()
     _append_backward_ops_(block, op_path, block, no_grad_dict, grad_to_var)