|
|
|
@ -34,8 +34,8 @@ set_seed(1)
|
|
|
|
|
parser = argparse.ArgumentParser(description="crnn training")
|
|
|
|
|
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
|
|
|
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
|
|
|
|
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
|
|
|
|
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
|
|
|
|
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'],
|
|
|
|
|
help='Running platform, only support Ascend now. Default is Ascend.')
|
|
|
|
|
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
|
|
|
|
|
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
|
|
|
|
|
parser.set_defaults(run_distribute=False)
|
|
|
|
@ -92,7 +92,7 @@ if __name__ == '__main__':
|
|
|
|
|
model = Model(net)
|
|
|
|
|
# define callbacks
|
|
|
|
|
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
|
|
|
|
if config.save_checkpoint:
|
|
|
|
|
if config.save_checkpoint and rank == 0:
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
|
|
|
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
|
|
|
|
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
|
|
|
|