amp add `best_choice` level

pull/5693/head
Wei Luning 5 years ago
parent bdf2082a3e
commit 5e1cba77f0

@ -121,12 +121,15 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2, O3]. Default: "O0".
level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
- auto: Set to level to recommended level in different devices. Set level to O2 on GPU, Set
level to O3 Ascend. The recommended level is choose by the export experience, cannot
always generalize. User should specify the level for special network.
O2 is recommended on GPU, O3 is recommended on Ascend.
@ -139,7 +142,17 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
"""
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2', 'O3'], Rel.IN, None)
validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN, None)
if level == "auto":
device_target = context.get_context('device_target')
if device_target == "GPU":
level = "O2"
elif device_target == "Ascend":
level = "O3"
else:
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
_check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
config = edict(config)

Loading…
Cancel
Save