@@ -53,30 +53,23 @@ if __name__ == '__main__':
     # define network
     backbone_net, head_net, net = define_net(config, args_opt.is_training)
 
-    # load the ckpt file to the network for fine tune or incremental learning
-    if args_opt.pretrain_ckpt:
-        if args_opt.train_method == "fine_tune":
-            load_ckpt(net, args_opt.pretrain_ckpt)
-        elif args_opt.train_method == "incremental_learn":
-            load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
-        elif args_opt.train_method == "train":
-            pass
-        else:
-            raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")
-
-    # CPU only support "incremental_learn"
-    if args_opt.train_method == "incremental_learn":
+    if args_opt.pretrain_ckpt and args_opt.freeze_layer == "backbone":
+        load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
         step_size = extract_features(backbone_net, args_opt.dataset_path, config)
         net = head_net
 
-    elif args_opt.train_method in ("train", "fine_tune"):
+    else:
         if args_opt.platform == "CPU":
-            raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".")
+            raise ValueError("CPU only support fine tune the head net, doesn't support fine tune the all net")
+        if args_opt.pretrain_ckpt:
+            load_ckpt(backbone_net, args_opt.pretrain_ckpt)
         dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
         step_size = dataset.get_dataset_size()
-    if step_size == 0:
-        raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
-            than batch_size in config.py")
+        if step_size == 0:
+            raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \
+                than batch_size in config.py")
 
     # Currently, only Ascend support switch precision.
     switch_precision(net, mstype.float16, config)
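
Note: after this hunk the old train_method switch is gone; the script now derives the training mode from pretrain_ckpt and freeze_layer alone. Assuming the command-line flags mirror the args_opt attribute names used above (an assumption, since the argument parser is outside this diff), CPU incremental learning with a frozen backbone would be launched roughly like this:

    # hypothetical invocation; flag names inferred from the args_opt attributes
    python train.py --platform CPU --dataset_path /path/to/dataset \
        --pretrain_ckpt mobilenetv2.ckpt --freeze_layer backbone
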
@@ -99,15 +92,32 @@ if __name__ == '__main__':
                        total_epochs=epoch_size,
                        steps_per_epoch=step_size))
 
-    if args_opt.train_method == "incremental_learn":
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
+    if args_opt.pretrain_ckpt is None or args_opt.freeze_layer == "none":
+        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
+              config.weight_decay, config.loss_scale)
+        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
+
+        cb = config_ckpoint(config, lr, step_size)
+        print("============== Starting Training ==============")
+        model.train(epoch_size, dataset, callbacks=cb)
+        print("============== End Training ==============")
+
-        network = WithLossCell(net, loss)
+    else:
+        opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()), lr, config.momentum, config.weight_decay)
+
+        network = WithLossCell(head_net, loss)
         network = TrainOneStepCell(network, opt)
         network.set_train()
 
         features_path = args_opt.dataset_path + '_features'
         idx_list = list(range(step_size))
+        rank = 0
+        if config.run_distribute:
+            rank = get_rank()
+        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
+        if not os.path.isdir(save_ckpt_path):
+            os.mkdir(save_ckpt_path)
 
         for epoch in range(epoch_size):
             random.shuffle(idx_list)
@@ -119,24 +129,8 @@ if __name__ == '__main__':
                 losses.append(network(feature, label).asnumpy())
             epoch_mseconds = (time.time()-epoch_start) * 1000
             per_step_mseconds = epoch_mseconds / step_size
-            print("epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
-            .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
+            print("epoch[{}/{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
+            .format(epoch + 1, epoch_size, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
             if (epoch + 1) % config.save_checkpoint_epochs == 0:
-                rank = 0
-                if config.run_distribute:
-                    rank = get_rank()
-                save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
-                save_checkpoint(network, os.path.join(save_ckpt_path, \
-                f"mobilenetv2_head_{epoch+1}.ckpt"))
+                save_checkpoint(net, os.path.join(save_ckpt_path, f"mobilenetv2_{epoch+1}.ckpt"))
         print("total cost {:5.4f} s".format(time.time() - start))
-
-    elif args_opt.train_method in ("train", "fine_tune"):
-        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
-              config.weight_decay, config.loss_scale)
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
-
-        cb = config_ckpoint(config, lr, step_size)
-        print("============== Starting Training ==============")
-        model.train(epoch_size, dataset, callbacks=cb)
-        print("============== End Training ==============")
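
Note: the body of the per-epoch loop between the second and third hunks is elided from this diff. A minimal sketch of the missing lines, assuming extract_features() saved one feature/label pair of .npy files per batch into features_path; the file-name pattern below is hypothetical, not taken from this diff, and time, os, np and mindspore's Tensor are assumed to be imported by the script already:

    # hypothetical reconstruction of the elided loop body
    epoch_start = time.time()
    losses = []
    for j in idx_list:
        # load one pre-extracted batch; naming assumed from extract_features()
        feature = Tensor(np.load(os.path.join(features_path, f"feature_{j}.npy")))
        label = Tensor(np.load(os.path.join(features_path, f"label_{j}.npy")))
        # network is the TrainOneStepCell built above: each call runs forward,
        # backward and one Momentum update, and returns the loss
        losses.append(network(feature, label).asnumpy())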