@@ -121,12 +121,15 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
             Default: None.
         optimizer (Optimizer): Optimizer to update the Parameter.
-        level (str): Supports [O0, O2, O3]. Default: "O0".
+        level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".

             - O0: Do not change.
             - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
               using dynamic loss scale.
             - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
+            - auto: Set level to the recommended level for the current device: O2 on GPU, O3 on Ascend. The
+              recommended level is chosen from expert experience and cannot always generalize, so users
+              should specify the level explicitly for special networks.

         O2 is recommended on GPU, O3 is recommended on Ascend.
@@ -139,7 +142,17 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     """
     validator.check_value_type('network', network, nn.Cell, None)
     validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
-    validator.check('level', level, "", ['O0', 'O2', 'O3'], Rel.IN, None)
+    validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN, None)
+
if level == "auto":
|
|
|
|
|
device_target = context.get_context('device_target')
|
|
|
|
|
if device_target == "GPU":
|
|
|
|
|
level = "O2"
|
|
|
|
|
elif device_target == "Ascend":
|
|
|
|
|
level = "O3"
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
|
|
|
|
|
|
|
|
|
|
     _check_kwargs(kwargs)
     config = dict(_config_level[level], **kwargs)
     config = edict(config)
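For reference, a minimal usage sketch of the new "auto" level follows. The toy Dense network, Momentum optimizer, and loss function are placeholders chosen for brevity, and the import path of build_train_network may differ across MindSpore versions; this only illustrates the behaviour added above, it is not part of the change.

import mindspore.nn as nn
from mindspore import context
from mindspore.train.amp import build_train_network  # import path may vary by version

# Any trainable Cell and Optimizer will do; these are stand-ins for a real model.
net = nn.Dense(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# With level="auto", build_train_network reads context.get_context('device_target')
# and resolves the level to O2 on GPU or O3 on Ascend; any other target raises ValueError.
context.set_context(device_target="GPU")
train_net = build_train_network(net, opt, loss_fn, level="auto")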