|
|
@ -60,7 +60,11 @@ if __name__ == "__main__":
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
|
|
|
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
|
|
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck)
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck)
|
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
|
|
|
|
|
|
|
|
|
|
|
|
if args.device_target == "CPU":
|
|
|
|
|
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
|
|
|
|
|
|
|
|
|
|
|
|
print("============== Starting Training ==============")
|
|
|
|
print("============== Starting Training ==============")
|
|
|
|
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
|
|
|
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
|
|
|