@@ -146,10 +146,14 @@ def main():
     loss_scale_manager = FixedLossScaleManager(
         cfg.loss_scale, drop_overflow_update=False)
-    config_ck = CheckpointConfig(
-        save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(
-        prefix=cfg.model, directory=output_dir, config=config_ck)
+    callbacks = [time_cb, loss_cb]
+    if cfg.save_checkpoint:
+        config_ck = CheckpointConfig(
+            save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
+        ckpoint_cb = ModelCheckpoint(
+            prefix=cfg.model, directory=output_dir, config=config_ck)
+        callbacks += [ckpoint_cb]
     lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                        decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
@@ -176,7 +180,7 @@ def main():
         amp_level=cfg.amp_level
     )
-    callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []
+    callbacks = callbacks if is_master else []
     if args.resume:
         real_epoch = cfg.epochs - cfg.resume_start_epoch
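Read together, the two hunks make the checkpoint callback optional and reuse one callback list throughout `main()`. A minimal sketch of the resulting wiring is below; it assumes `cfg`, `time_cb`, `loss_cb`, `batches_per_epoch`, `output_dir`, and `is_master` are defined earlier in `main()`, and the import path may differ across MindSpore versions.

```python
# Sketch of the callback setup after this patch (names above are assumed context).
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

callbacks = [time_cb, loss_cb]
if cfg.save_checkpoint:
    # Build the checkpoint callback only when checkpointing is enabled,
    # so ckpoint_cb is never referenced while undefined.
    config_ck = CheckpointConfig(
        save_checkpoint_steps=batches_per_epoch,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(
        prefix=cfg.model, directory=output_dir, config=config_ck)
    callbacks += [ckpoint_cb]

# In distributed runs, only the master rank keeps the callbacks.
callbacks = callbacks if is_master else []
```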