|
|
|
|
@ -80,9 +80,9 @@ if __name__ == "__main__":
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
|
|
|
|
|
callback.append(ckpoint_cb)
|
|
|
|
|
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
|
|
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
|
|
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
|
|
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
|
|
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
|
|
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
|
|
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
|
|
|
|
|
net.set_train()
|
|
|
|
|
model_fine_tune(args_opt, net, 'layer')
|
|
|
|
|
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
|
|
|
|
|
|