fix resnext50 gpu amp

pull/7155/head
zhaoting 4 years ago
parent 8df757143a
commit c470c9acef

@ -240,12 +240,8 @@ def train(cloud_args=None):
else:
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
if args.platform == "Ascend":
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")
else:
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O2")
# checkpoint save
progress_cb = ProgressMonitor(args)

Loading…
Cancel
Save