fix add grad bug test=develop (#22924)

* fix add grad bug test=develop

* update style test=develop
revert-22710-feature/integrated_ps_api
yaoxuefeng 5 years ago committed by GitHub
parent 9848f8f388
commit c5cbe7f07b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -189,6 +189,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
sparse_table_index = 0
for loss in losses:
prog_id = str(id(loss.block.program))
# param_grads of program
params_grads = sorted(
fluid.backward.append_backward(loss, parameter_list,
no_grad_set),
key=lambda x: x[0].name)
if prog_id not in program_id_set:
program_id_set.add(prog_id)
sparse_table = self._find_multi_distributed_lookup_table([loss])
@ -215,11 +221,6 @@ class DistributedAdam(DistributedOptimizerImplBase):
loss.block.program, sparse_table)
prog_id_to_sparse_grads[prog_id] = grads_dict
# param_grads of program
params_grads = sorted(
fluid.backward.append_backward(loss, parameter_list,
no_grad_set),
key=lambda x: x[0].name)
if prog_id not in prog_id_to_param_grads:
prog_id_to_param_grads[prog_id] = []
prog_id_to_param_grads[prog_id].append(params_grads)

Loading…
Cancel
Save