|
|
|
@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
|
|
|
|
|
from src.utils.logging import get_logger
|
|
|
|
|
from src.utils.optimizers__init__ import get_param_groups
|
|
|
|
|
from src.image_classification import get_network
|
|
|
|
|
from src.utils.auto_mixed_precision import auto_mixed_precision
|
|
|
|
|
from src.config import config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -273,8 +272,8 @@ def train(cloud_args=None):
|
|
|
|
|
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
|
|
|
|
metrics={'acc'}, amp_level="O3")
|
|
|
|
|
else:
|
|
|
|
|
auto_mixed_precision(network)
|
|
|
|
|
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})
|
|
|
|
|
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
|
|
|
|
|
metrics={'acc'}, amp_level="O2")
|
|
|
|
|
|
|
|
|
|
# checkpoint save
|
|
|
|
|
progress_cb = ProgressMonitor(args)
|
|
|
|
|