|
|
|
@ -82,11 +82,11 @@ def _check_kwargs(key_words):
|
|
|
|
|
validator.check_type_name('cast_model_type', key_words['cast_model_type'],
|
|
|
|
|
[mstype.float16, mstype.float32], None)
|
|
|
|
|
if 'keep_batchnorm_fp32' in key_words:
|
|
|
|
|
validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None)
|
|
|
|
|
validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
|
|
|
|
|
if 'loss_scale_manager' in key_words:
|
|
|
|
|
loss_scale_manager = key_words['loss_scale_manager']
|
|
|
|
|
if loss_scale_manager:
|
|
|
|
|
validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None)
|
|
|
|
|
validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _add_loss_network(network, loss_fn, cast_model_type):
|
|
|
|
@ -104,7 +104,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|
|
|
|
label = F.mixed_precision_cast(mstype.float32, label)
|
|
|
|
|
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
|
|
|
|
|
|
|
|
|
|
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
|
|
|
|
|
validator.check_value_type('loss_fn', loss_fn, nn.Cell)
|
|
|
|
|
if cast_model_type == mstype.float16:
|
|
|
|
|
network = WithLossCell(network, loss_fn)
|
|
|
|
|
else:
|
|
|
|
@ -140,9 +140,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
|
|
|
|
|
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
|
|
|
|
scale the loss by LossScaleManager. If set, overwrite the level setting.
|
|
|
|
|
"""
|
|
|
|
|
validator.check_value_type('network', network, nn.Cell, None)
|
|
|
|
|
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
|
|
|
|
|
validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN, None)
|
|
|
|
|
validator.check_value_type('network', network, nn.Cell)
|
|
|
|
|
validator.check_value_type('optimizer', optimizer, nn.Optimizer)
|
|
|
|
|
validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)
|
|
|
|
|
|
|
|
|
|
if level == "auto":
|
|
|
|
|
device_target = context.get_context('device_target')
|
|
|
|
|