[cherry-pick] Fix AMP support for fleet (#29505)

release/2.0-rc1 v2.0.0-rc1
Aurelius84 5 years ago committed by GitHub
parent 4d51cd73b3
commit d82d59e6e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,6 +66,8 @@ class OptimizerWithMixedPrecision(object):
self._loss_scaling = None
self._init_loss_scaling = init_loss_scaling
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = incr_every_n_steps
self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf

@ -3751,7 +3751,9 @@ class PipelineOptimizer(object):
if framework.in_dygraph_mode():
raise Exception("In dygraph, don't support PipelineOptimizer.")
if not isinstance(optimizer, Optimizer) and not isinstance(
optimizer, paddle.optimizer.Optimizer):
optimizer, paddle.optimizer.Optimizer) and not isinstance(
optimizer, paddle.fluid.contrib.mixed_precision.decorator.
OptimizerWithMixedPrecision):
raise ValueError("The 'optimizer' parameter for "
"PipelineOptimizer must be an instance of "
"Optimizer, but the given type is {}.".format(

Loading…
Cancel
Save