@@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
    If set to `mstype.float16`, the network is trained in `float16` mode. If set, this overrides the `level` setting.
keep_batchnorm_fp32 (bool): Whether to keep BatchNorm running in `float32`. If set, this overrides the `level` setting.
    `keep_batchnorm_fp32` takes effect only when `cast_model_type` is `float16`.
loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise,
    the loss is scaled by the given LossScaleManager. If set, this overrides the `level` setting.
"""