Fix gradients (#20857)

* fix_gradients

* fix_gradients, test=develop
yaoxuefeng
lvmengsi 6 years ago committed by GitHub
parent 03ba0fdae6
commit aadd81b662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1247,9 +1247,13 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
target = targets[i]
if grad is None:
grad_name = _append_grad_suffix_(target.name)
target_shape = paddle.fluid.layers.shape(target)
target_shape = target.name + '_shape'
block.desc.append_op().copy_from(
_create_op_desc_("shape", {'Input': [target.name]},
{"Out": [target_shape]}, {}))
input_grad_names_set.add(target_shape)
op_desc = _create_op_desc_("fill_constant",
{"ShapeTensor": [target_shape.name]},
{"ShapeTensor": [target_shape]},
{"Out": [grad_name]}, {
"shape": target.shape,
"value": 1.0,

Loading…
Cancel
Save