|
|
|
@ -88,6 +88,7 @@ def main():
|
|
|
|
|
parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
|
|
|
|
|
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
|
|
|
|
|
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
|
|
|
|
|
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
|
|
|
|
|
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
|
|
|
|
|
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
@ -150,17 +151,20 @@ def main():
|
|
|
|
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)
|
|
|
|
|
|
|
|
|
|
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
|
|
|
|
|
warmup_epochs=max(args_opt.epoch_size // 20, 1),
|
|
|
|
|
total_epochs=args_opt.epoch_size,
|
|
|
|
|
steps_per_epoch=dataset_size))
|
|
|
|
|
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
|
|
|
|
|
net = TrainingWrapper(net, opt, loss_scale)
|
|
|
|
|
|
|
|
|
|
if args_opt.pre_trained:
|
|
|
|
|
if args_opt.pre_trained_epoch_size <= 0:
|
|
|
|
|
raise KeyError("pre_trained_epoch_size must be greater than 0.")
|
|
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained)
|
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
|
|
|
|
|
|
lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size,
|
|
|
|
|
lr_init=0, lr_end=0, lr_max=args_opt.lr,
|
|
|
|
|
warmup_epochs=max(350 // 20, 1),
|
|
|
|
|
total_epochs=350,
|
|
|
|
|
steps_per_epoch=dataset_size))
|
|
|
|
|
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
|
|
|
|
|
net = TrainingWrapper(net, opt, loss_scale)
|
|
|
|
|
|
|
|
|
|
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
|
|
|
|
|
|
|
|
|
|
model = Model(net)
|
|
|
|
|