fix API param bug of recompute.backward() (#22582)

* fix API param bug of recompute.backward(), test=develop
Author: mapingshuo (committed via GitHub)
Parent: 61fef9754b
Commit: 08a772cb46

@@ -3850,12 +3850,12 @@ class RecomputeOptimizer(Optimizer):
     sgd = fluid.optimizer.Adam(learning_rate=0.01)
     sgd = fluid.optimizer.RecomputeOptimizer(sgd)
     sgd._set_checkpoints([fc_1, pred])
     params_grads = sgd.backward(
         cost,
         startup_program=None,
         parameter_list=None,
-        no_grad_set=None,
-        checkpoints=[fc_1, pred])
+        no_grad_set=None)
     program = cost.block.program
     with framework.program_guard(program, None):
@@ -3871,8 +3871,7 @@ class RecomputeOptimizer(Optimizer):
              startup_program=None,
              parameter_list=None,
              no_grad_set=None,
-             callbacks=None,
-             checkpoints=None):
+             callbacks=None):
     """
     call append_backward with checkpoints.
@@ -3906,12 +3905,12 @@ class RecomputeOptimizer(Optimizer):
     sgd = fluid.optimizer.Adam(learning_rate=0.01)
     sgd = fluid.optimizer.RecomputeOptimizer(sgd)
     sgd._set_checkpoints([fc_1, pred])
     params_grads = sgd.backward(
         cost,
         startup_program=None,
         parameter_list=None,
-        no_grad_set=None,
-        checkpoints=[fc_1, pred])
+        no_grad_set=None)
     print("Finished backward")
 """
@@ -3958,12 +3957,12 @@ class RecomputeOptimizer(Optimizer):
     sgd = fluid.optimizer.Adam(learning_rate=0.01)
     sgd = fluid.optimizer.RecomputeOptimizer(sgd)
     sgd._set_checkpoints([fc_1, pred])
     params_grads = sgd.backward(
         cost,
         startup_program=None,
         parameter_list=None,
-        no_grad_set=None,
-        checkpoints=[fc_1, pred])
+        no_grad_set=None)
     optimize_ops = sgd.apply_optimize(
         cost, startup_program=None, params_grads=params_grads)
@@ -3993,8 +3992,7 @@ class RecomputeOptimizer(Optimizer):
         loss,
         startup_program=startup_program,
         parameter_list=parameter_list,
-        no_grad_set=no_grad_set,
-        checkpoints=self._checkpoints)
+        no_grad_set=no_grad_set)
     if grad_clip:
         # TODO(guru4elephant): should add grad_clip for static graph

@@ -791,8 +791,7 @@ class TestRecomputeOptimizer(unittest.TestCase):
         mean_out,
         startup_program=None,
         parameter_list=None,
-        no_grad_set=None,
-        checkpoints=[b1_out])
+        no_grad_set=None)
     # apply gradient
     program = mean_out.block.program

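Taken together, the hunks above remove the redundant checkpoints argument from RecomputeOptimizer.backward(): checkpoint tensors are registered once via _set_checkpoints() and picked up internally when building the backward pass. A minimal sketch of the updated call pattern follows; it assumes the pre-2.0 paddle.fluid static-graph API, and the small network (input_x, input_y, fc_1, prediction, sum_cost) is illustrative only, not part of this commit.

import paddle.fluid as fluid

# Illustrative static-graph network; any fluid program works the same way.
input_x = fluid.layers.data(name="x", shape=[32], dtype="float32")
input_y = fluid.layers.data(name="y", shape=[1], dtype="int64")
fc_1 = fluid.layers.fc(input=input_x, size=64, act="relu")
prediction = fluid.layers.fc(input=fc_1, size=2, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)

sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)

# Checkpoints are registered up front ...
sgd._set_checkpoints([fc_1, prediction])

# ... so backward() no longer takes a `checkpoints` argument.
params_grads = sgd.backward(
    sum_cost,
    startup_program=None,
    parameter_list=None,
    no_grad_set=None)

# Gradients are then applied as in the docstring example above.
optimize_ops = sgd.apply_optimize(
    sum_cost, startup_program=None, params_grads=params_grads)

After this change, passing checkpoints= to backward() raises a TypeError because the parameter no longer exists; the checkpoint tensors themselves are unchanged.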