@@ -1,6 +1,7 @@
 from paddle.v2.fluid import framework as framework
 from . import core
 import collections
+import pdb

 __all__ = ['append_backward_ops']

@@ -15,7 +16,8 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
         op_desc_list[i].rename_output(old_name, new_name)


-def backward_impl(block,
+def backward_impl(target,
+                  block,
                   target_block,
                   no_grad_set,
                   grad_info_map,
@@ -29,8 +31,8 @@ def backward_impl(block,
             sub_block_idx = each_op.block_attr("sub_block")
             sub_block = program.block(sub_block_idx)
             grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            backward_impl(sub_block, grad_sub_block, no_grad_set, grad_info_map,
-                          callback)
+            backward_impl(target, sub_block, grad_sub_block, no_grad_set,
+                          grad_info_map, callback)
             grad_sub_block_list.append(grad_sub_block)
         grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                              no_grad_set[block.idx],
@@ -46,6 +48,7 @@ def backward_impl(block,
     for pos, op_desc in enumerate(grad_op_descs):
         for var_name in op_desc.input_arg_names():
             if len(var_inputs[var_name]) > 1:
+                pdb.set_trace()
                 pending_sum_ops.append((core.OpDesc(
                     type="sum_op",
                     inputs=var_inputs[var_name],
@@ -55,7 +58,7 @@ def backward_impl(block,
         for var_name in op_desc.output_arg_names():
             if len(var_inputs[var_name]) == 0:
                 # it's the first time we get the variable
-                var_inputs[var_name] = var_name
+                var_inputs[var_name] = [var_name]
             else:
                 if len(var_inputs[var_name]) == 1:
                     new_name = var_name + "@RENAME@" + \
@@ -73,8 +76,9 @@ def backward_impl(block,
                 var_inputs[var_name].append(new_name)
     for var_name, inputs in var_inputs.iteritems():
         if len(inputs) > 1:
-            pending_sum_ops.append((core.OpDesc(
-                type="sum_op", inputs=inputs, outputs=var_name, attrs={}),
+            pdb.set_trace()
+            pending_sum_ops.append((core.OpDesc("sum_op", {"X": inputs},
+                                                {"Out": [var_name]}, {}),
                                     len(grad_op_descs)))
     # TODO: remove op in no grad set

@@ -84,6 +88,7 @@ def backward_impl(block,
     # create new gradient variables in the target block desc
     for op_desc in grad_op_descs:
         for grad_var_name in op_desc.output_arg_names():
+            grad_var_name = grad_var_name.encode("ascii")
             if target_block.desc.has_var(
                     grad_var_name) or grad_var_name == core.get_empty_var_name(
             ):
@@ -93,6 +98,16 @@ def backward_impl(block,
                 continue
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                          target_block)
+    if target_block.idx == 0:
+        grad_target_name = target.name + "@GRAD"
+        target_block.desc.var(grad_target_name)
+        grad_op_descs.insert(
+            0,
+            core.OpDesc(u"fill_constant", {}, {
+                u"Out": [unicode(grad_target_name, "ascii")]
+            }, {u"shape": [1],
+                u"value": 1.0,
+                u"dtype": core.DataType.FP32}))
     # insert backward operators to target_block
     for op_desc in grad_op_descs:
         target_block.desc.append_allocated_op(op_desc)
@@ -118,18 +133,22 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
     assert isinstance(loss, framework.Variable)

     if no_grad_set is None:
+        no_grad_set = dict()
         program = loss.block.program
         assert isinstance(program, framework.Program)
-        no_grad_set = list()
         for block in program.blocks:
             assert isinstance(block, framework.Block)
+            block_no_grad_set = set()
             for var in block.vars.itervalues():
                 assert isinstance(var, framework.Variable)
                 if var.stop_gradient:
-                    no_grad_set.append(var.name)
-        no_grad_set = set(no_grad_set)
+                    block_no_grad_set.add(var.name)
+            no_grad_set[block.idx] = block_no_grad_set

-    param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
+    grad_info_map = dict()
+    root_block = loss.block.program.block(0)
+    backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
+    pdb.set_trace()
     if parameter_list is not None:
         parameters = parameter_list
     else:
@@ -137,9 +156,9 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
         parameters = [param.name for param in params]
     params_and_grads = []
     for param in parameters:
-        if param not in param_grad_map:
+        if param not in grad_info_map:
             raise ValueError("param %s is not in map" % param)
-        grad_info = param_grad_map[param]
+        grad_info = grad_info_map[param]
         grad_block = loss.block.program.block(grad_info[1])
         if not grad_block.has_var(grad_info[0]):
             raise ValueError("grad block[{0}] did not have grad var {1}".format(
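
Reviewer note: the hunks above accumulate gradients by renaming. When more than one backward op writes the same gradient variable, the extra outputs get an "@RENAME@" alias and a pending sum op is queued to merge the aliases back into the original name; at the root block a fill_constant op seeds the target's "@GRAD" variable with 1.0. The sketch below shows only the rename-and-merge bookkeeping in plain Python: insert_sum_ops and the dict-based op records are illustrative stand-ins, not core.OpDesc's real interface, and for simplicity it leaves the first writer's name unchanged rather than renaming every writer as the diff does.

import collections

def insert_sum_ops(grad_ops):
    # grad_ops: list of {"type": str, "outputs": [var names]} records.
    versions = collections.defaultdict(list)
    for op in grad_ops:
        for i, name in enumerate(op["outputs"]):
            if versions[name]:
                # another op already wrote this gradient: alias this output
                alias = "%s@RENAME@%d" % (name, len(versions[name]))
                op["outputs"][i] = alias
                versions[name].append(alias)
            else:
                versions[name].append(name)
    sum_ops = []
    for name, writers in versions.items():
        if len(writers) > 1:
            # merge every alias back into the original gradient variable
            sum_ops.append({"type": "sum", "inputs": list(writers),
                            "outputs": [name]})
    return grad_ops + sum_ops

ops = [{"type": "mul_grad", "outputs": ["x@GRAD"]},
       {"type": "add_grad", "outputs": ["x@GRAD"]}]
for op in insert_sum_ops(ops):
    print(op)
# the second writer becomes x@GRAD@RENAME@1, and a trailing sum op
# merges x@GRAD and x@GRAD@RENAME@1 back into x@GRAD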