[fleet] combine amp and gradient merge, test=develop (#30086)

WangXi committed via GitHub
parent 88e6dc4ac5
commit ab04997846

@@ -25,7 +25,6 @@ class AMPOptimizer(MetaOptimizerBase):
             "LarsOptimizer",
             "LambOptimizer",
             "RecomputeOptimizer",
-            "GradientMergeOptimizer",
             "GraphExecutionOptimizer",
         ]
         self.meta_optimizers_black_list = ["DGCOptimizer"]

@@ -21,6 +21,7 @@ class GradientMergeOptimizer(MetaOptimizerBase):
         self.inner_opt = optimizer
         self.wrapped_opt = None
         self.meta_optimizers_white_list = [
+            "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
             "GraphExecutionOptimizer",

@@ -159,9 +159,6 @@ class OptimizerWithMixedPrecision(object):
         params_grads = self._optimizer.backward(
             self._scaled_loss, startup_program, parameter_list, no_grad_set,
             callbacks)
-        # Change the op_role_var attr for some ops, so that gradients
-        # transferred across GPUs can be FP16.
-        update_role_var_grad(train_program, params_grads)
         return params_grads

     def apply_gradients(self, params_grads):
@@ -176,6 +173,10 @@ class OptimizerWithMixedPrecision(object):
             A list of optimize operators.
         """
+        # Change the op_role_var attr for some ops, so that gradients
+        # transferred across GPUs can be FP16.
+        update_role_var_grad(self._train_program, params_grads)
+
         grads = [g for _, g in params_grads]
         if not self._is_distributed:
             with self._train_program._optimized_guard(grads):
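
Taken together, these two hunks move update_role_var_grad from backward() into apply_gradients(), pointing it at self._train_program. The comment's intent is unchanged (mark op_role_var so gradients transferred across GPUs can stay FP16); running the rewrite in apply_gradients means it now happens after an outer optimizer such as gradient merge has finished transforming the program. A minimal call-order sketch (stubs only, not the Paddle implementation):

# Stub sketch of the new call order in the mixed-precision wrapper.
class InnerOptStub:
    def backward(self, loss):
        return [("w", "w@GRAD")]        # fake (param, grad) pairs

    def apply_gradients(self, params_grads):
        return ["sgd_op"]               # fake optimize ops


class MixedPrecisionSketch:
    def __init__(self, inner):
        self._optimizer = inner

    def backward(self, loss):
        # After this commit: backward only produces (param, grad) pairs;
        # it no longer rewrites op_role_var.
        return self._optimizer.backward(loss)

    def update_role_var_grad(self, params_grads):
        pass                            # placeholder for the attr rewrite

    def apply_gradients(self, params_grads):
        # The op_role_var rewrite happens here instead, just before the
        # gradients are applied.
        self.update_role_var_grad(params_grads)
        return self._optimizer.apply_gradients(params_grads)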

@@ -46,6 +46,19 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
         self.assertIn('@GradientMerge', ''.join(vars))
         self.assertIn('subprog', ''.join(vars))

+    def test_gm_amp_optimizer(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'gradient_merge')
+        self.set_strategy(strategy, 'amp')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        print(train_prog)
+
+        vars = [x.name for x in train_prog.list_vars()]
+        self.assertIn('@GradientMerge', ''.join(vars))
+        self.assertIn('cast', ''.join(vars))

 if __name__ == "__main__":
     unittest.main()
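
For context, the combination exercised by this new test corresponds roughly to the following user-facing setup (a sketch assuming the collective fleet API of this Paddle version; the k_steps value is arbitrary):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True                  # mixed precision as the inner pass
strategy.gradient_merge = True       # gradient merge now wraps AMP
strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}

optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(avg_cost) would now build a program containing both
# '@GradientMerge' variables and fp16 'cast' ops, as the test asserts.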
