|
|
|
@ -62,6 +62,7 @@ class Model:
|
|
|
|
|
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
|
|
|
|
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
|
|
|
|
|
e.g. Use `loss_scale_manager=None` to set the value.
|
|
|
|
|
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> class Net(nn.Cell):
|
|
|
|
@ -96,7 +97,10 @@ class Model:
|
|
|
|
|
self._optimizer = optimizer
|
|
|
|
|
self._loss_scale_manager = None
|
|
|
|
|
self._loss_scale_manager_set = False
|
|
|
|
|
self._keep_bn_fp32 = True
|
|
|
|
|
self._check_kwargs(kwargs)
|
|
|
|
|
if 'keep_batchnorm_fp32' in kwargs:
|
|
|
|
|
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
|
|
|
|
if 'loss_scale_manager' in kwargs:
|
|
|
|
|
self._loss_scale_manager = kwargs['loss_scale_manager']
|
|
|
|
|
self._loss_scale_manager_set = True
|
|
|
|
@ -112,7 +116,7 @@ class Model:
|
|
|
|
|
|
|
|
|
|
def _check_kwargs(self, kwargs):
|
|
|
|
|
for arg in kwargs:
|
|
|
|
|
if arg not in ['loss_scale_manager']:
|
|
|
|
|
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
|
|
|
|
|
raise ValueError(f"Unsupport arg '{arg}'")
|
|
|
|
|
|
|
|
|
|
def _build_train_network(self):
|
|
|
|
@ -124,12 +128,14 @@ class Model:
|
|
|
|
|
self._optimizer,
|
|
|
|
|
self._loss_fn,
|
|
|
|
|
level=self._amp_level,
|
|
|
|
|
loss_scale_manager=self._loss_scale_manager)
|
|
|
|
|
loss_scale_manager=self._loss_scale_manager,
|
|
|
|
|
keep_batchnorm_fp32=self._keep_bn_fp32)
|
|
|
|
|
else:
|
|
|
|
|
network = amp.build_train_network(network,
|
|
|
|
|
self._optimizer,
|
|
|
|
|
self._loss_fn,
|
|
|
|
|
level=self._amp_level)
|
|
|
|
|
level=self._amp_level,
|
|
|
|
|
keep_batchnorm_fp32=self._keep_bn_fp32)
|
|
|
|
|
elif self._loss_fn:
|
|
|
|
|
network = nn.WithLossCell(network, self._loss_fn)
|
|
|
|
|
# If need to check if loss_fn is not None, but optimizer is None
|
|
|
|
|