diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index 3bddd6d5d0..e2da1618bf 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
             If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
         keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
+            `keep_batchnorm_fp32` takes effect only when `cast_model_type` is `mstype.float16`.
         loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
             scale the loss by LossScaleManager. If set, overwrite the level setting.
     """
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 844480d20d..8dfa6fd80d 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -174,7 +174,7 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
+            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
             self._eval_indexes = [0, 1, 2]

         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
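
Below is a minimal usage sketch of the two behaviors this patch touches: `keep_batchnorm_fp32` is only honored when `cast_model_type` forces a `float16` cast, and `Model` now inserts the fp32 cast into the eval network for both the `O2` and `O3` amp levels. The network, loss, and optimizer here are illustrative placeholders, not part of this patch.

import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.train.model import Model
from mindspore.train.amp import build_train_network

# Placeholder network/loss/optimizer; any Cell and loss pair works here.
net = nn.Dense(16, 10)
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# keep_batchnorm_fp32 is only consulted together with a float16 cast;
# with cast_model_type=mstype.float32 it is ignored.
train_net = build_train_network(net, opt, loss, level="O0",
                                cast_model_type=mstype.float16,
                                keep_batchnorm_fp32=True)

# With amp_level "O2" or "O3", Model now builds an eval network whose
# outputs are cast back to float32 before loss and metrics are computed.
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"accuracy"},
              amp_level="O3")

Before this change only "O2" triggered the cast, so an "O3" model (everything cast to float16) fed float16 outputs into float32 loss and metric computation during evaluation; the `in ["O2", "O3"]` check restores the cast for "O3" as well.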