|
|
|
@ -27,13 +27,12 @@ from src.config import config
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Deeplabv3 training")
|
|
|
|
parser = argparse.ArgumentParser(description="Deeplabv3 training")
|
|
|
|
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
|
|
|
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
|
|
|
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
|
|
|
|
parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.')
|
|
|
|
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
|
|
|
|
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
|
|
|
|
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
|
|
|
|
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
|
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
|
|
|
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
|
|
|
|
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
|
|
|
|
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
|
|
|
|
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
|
|
|
|
parser.add_argument('--max_checkpoint_num', type=int, default=5, help='Max checkpoint number.')
|
|
|
|
|
|
|
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
|
|
|
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
|
|
|
|
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
|
|
|
|
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
@ -80,7 +79,7 @@ if __name__ == "__main__":
|
|
|
|
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
|
|
|
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
|
|
|
|
callback.append(ckpoint_cb)
|
|
|
|
callback.append(ckpoint_cb)
|
|
|
|
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
|
|
|
|
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
|
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
|
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
|
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
|
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
|
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
|
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
|
|
|
|
|