!5693 [feature] add a level `auto` for amp.

Merge pull request !5693 from vlne-v1/amp-best_choice
pull/5693/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 08ba8e71ed

@ -121,12 +121,15 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2, O3]. Default: "O0".
level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
- auto: Set to level to recommended level in different devices. Set level to O2 on GPU, Set
level to O3 Ascend. The recommended level is choose by the export experience, cannot
always generalize. User should specify the level for special network.
O2 is recommended on GPU, O3 is recommended on Ascend.
@ -139,7 +142,17 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
"""
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2', 'O3'], Rel.IN, None)
validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN, None)
if level == "auto":
device_target = context.get_context('device_target')
if device_target == "GPU":
level = "O2"
elif device_target == "Ascend":
level = "O3"
else:
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
_check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
config = edict(config)

Loading…
Cancel
Save