|
|
|
@ -154,7 +154,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
|
|
|
|
|
loss_scale = loss_scale_manager.get_loss_scale()
|
|
|
|
|
update_cell = loss_scale_manager.get_update_cell()
|
|
|
|
|
if update_cell is not None:
|
|
|
|
|
if not (context.get_context("enable_ge") or (context.get_context("device_target") == "GPU")):
|
|
|
|
|
# only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
|
|
|
|
|
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
|
|
|
|
|
raise ValueError("Only `loss_scale_manager=None` and "
|
|
|
|
|
"`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
|
|
|
|
|
"are supported in current version. If you use `O2` option, please"
|
|
|
|
|