@@ -19,6 +19,7 @@ from ... import unique_name
 from . import fp16_utils
 from .fp16_utils import create_master_params_grads, master_param_to_train_param
 from .fp16_utils import update_loss_scaling, rewrite_program
+from .fp16_lists import AutoMixedPrecisionLists

 __all__ = ["decorate"]
@@ -34,6 +35,7 @@ class OptimizerWithMixedPrecison(object):
     Args:
         optimizer (Optimizer): A common Optimizer object.
+        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
         init_loss_scaling (float): The initial loss scaling factor.
         use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
         incr_every_n_steps(int): Increases loss scaling every n consecutive
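The four scaling knobs documented above drive the dynamic loss-scaling policy. As a reading aid, here is a minimal Python sketch of that policy, assuming the counter-reset semantics the docstring describes; it is illustrative only, not the `fp16_utils.update_loss_scaling` implementation:

```python
# Sketch of the dynamic loss-scaling policy these arguments configure.
# State and names are illustrative, not taken from fp16_utils.
def update_loss_scaling_sketch(state, grads_are_finite,
                               incr_every_n_steps=1000,
                               decr_every_n_nan_or_inf=2,
                               incr_ratio=2.0,
                               decr_ratio=0.5):
    if grads_are_finite:
        state["good_steps"] += 1
        state["bad_steps"] = 0
        # After n consecutive finite-gradient steps, grow the scale.
        if state["good_steps"] == incr_every_n_steps:
            state["loss_scaling"] *= incr_ratio
            state["good_steps"] = 0
    else:
        state["bad_steps"] += 1
        state["good_steps"] = 0
        # After n steps with nan/inf gradients, shrink the scale.
        if state["bad_steps"] == decr_every_n_nan_or_inf:
            state["loss_scaling"] *= decr_ratio
            state["bad_steps"] = 0
    return state
```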
@@ -48,10 +50,11 @@ class OptimizerWithMixedPrecison(object):
     """

-    def __init__(self, optimizer, init_loss_scaling, use_dynamic_loss_scaling,
-                 incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
-                 decr_ratio):
+    def __init__(self, optimizer, amp_lists, init_loss_scaling,
+                 use_dynamic_loss_scaling, incr_every_n_steps,
+                 decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
         self._optimizer = optimizer
+        self._amp_lists = amp_lists
         self._param_grads = None
         self._train_program = default_main_program()
         self._startup_prog = default_startup_program()
@@ -120,7 +123,7 @@ class OptimizerWithMixedPrecison(object):
             A list of (param, grad), which is a tuple of a parameter and its
             gradient respectively, and the scaled loss.
         """
-        rewrite_program(self._train_program)
+        rewrite_program(self._train_program, self._amp_lists)
         scaled_loss = loss * self._loss_scaling
         self._param_grads = self._optimizer.backward(
             scaled_loss, startup_program, parameter_list, no_grad_set,
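For context on the `scaled_loss = loss * self._loss_scaling` line above: multiplying the loss multiplies every gradient by the same factor via the chain rule, which keeps small fp16 gradients from flushing to zero; the factor is divided back out before the weight update. A tiny NumPy sketch of the underflow problem (illustrative numbers, not from this codebase):

```python
# Why loss scaling works, in miniature.
import numpy as np

loss_scaling = 128.0
grad_fp32 = np.float32(2e-8)  # a true gradient too small for fp16

# Without scaling, the fp16 gradient underflows to zero.
assert np.float16(grad_fp32) == 0.0

# Scaling the loss scales every gradient by the same factor (chain rule),
# keeping it representable in fp16; it is unscaled before the update.
scaled = np.float16(grad_fp32 * loss_scaling)
recovered = np.float32(scaled) / loss_scaling
assert recovered > 0.0
```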
@@ -189,6 +192,7 @@ class OptimizerWithMixedPrecison(object):


 def decorate(optimizer,
+             amp_lists=None,
              init_loss_scaling=1.0,
              incr_every_n_steps=1000,
              decr_every_n_nan_or_inf=2,
@@ -200,6 +204,7 @@ def decorate(optimizer,

     Args:
         optimizer(Optimizer): A common Optimizer.
+        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
         init_loss_scaling(float): The initial loss scaling factor.
         incr_every_n_steps(int): Increases loss scaling every n consecutive
                                  steps with finite gradients.
@@ -227,9 +232,10 @@ def decorate(optimizer,
         scaled_loss, _, _ = mp_optimizer.minimize(loss)
     """
+    if amp_lists is None:
+        amp_lists = AutoMixedPrecisionLists()
     mp_optimizer = OptimizerWithMixedPrecison(
-        optimizer, init_loss_scaling, use_dynamic_loss_scaling,
+        optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling,
         incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio)

     return mp_optimizer
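End to end, the diff threads a user-supplied `AutoMixedPrecisionLists` from `decorate()` through to `rewrite_program()`. A hedged usage sketch in the fluid-era API; the import paths and the no-argument `AutoMixedPrecisionLists()` constructor are assumptions based on this diff rather than a verified public API:

```python
import paddle.fluid as fluid
# Import paths assumed from this module's location in the diff.
from paddle.fluid.contrib.mixed_precision import decorate
from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists

# A minimal network producing a loss to minimize.
x = fluid.layers.data(name='x', shape=[32], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square(pred - y))

optimizer = fluid.optimizer.SGD(learning_rate=0.01)

# amp_lists is optional: per this diff, decorate() falls back to
# AutoMixedPrecisionLists() when it is None.
mp_optimizer = decorate(
    optimizer,
    amp_lists=AutoMixedPrecisionLists(),
    init_loss_scaling=128.0,
    use_dynamic_loss_scaling=True)

scaled_loss, _, _ = mp_optimizer.minimize(loss)
```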