|
|
|
@ -35,7 +35,7 @@ from src.config import set_config
|
|
|
|
|
|
|
|
|
|
from src.args import train_parse_args
|
|
|
|
|
from src.utils import context_device_init, switch_precision, config_ckpoint
|
|
|
|
|
from src.models import CrossEntropyWithLabelSmooth, define_net
|
|
|
|
|
from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt
|
|
|
|
|
|
|
|
|
|
set_seed(1)
|
|
|
|
|
|
|
|
|
@ -50,7 +50,18 @@ if __name__ == '__main__':
|
|
|
|
|
context_device_init(config)
|
|
|
|
|
|
|
|
|
|
# define network
|
|
|
|
|
backbone_net, head_net, net = define_net(args_opt, config)
|
|
|
|
|
backbone_net, head_net, net = define_net(config)
|
|
|
|
|
|
|
|
|
|
# load the ckpt file to the network for fine tune or incremental leaning
|
|
|
|
|
if args_opt.pretrain_ckpt:
|
|
|
|
|
if args_opt.train_method == "fine_tune":
|
|
|
|
|
load_ckpt(net, args_opt.pretrain_ckpt)
|
|
|
|
|
elif args_opt.train_method == "incremental_learn":
|
|
|
|
|
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
|
|
|
|
|
elif args_opt.train_method == "train":
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")
|
|
|
|
|
|
|
|
|
|
# CPU only support "incremental_learn"
|
|
|
|
|
if args_opt.train_method == "incremental_learn":
|
|
|
|
@ -60,7 +71,11 @@ if __name__ == '__main__':
|
|
|
|
|
elif args_opt.train_method in ("train", "fine_tune"):
|
|
|
|
|
if args_opt.platform == "CPU":
|
|
|
|
|
raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".")
|
|
|
|
|
dataset, step_size = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
|
|
|
|
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
|
|
|
|
|
step_size = dataset.get_dataset_size()
|
|
|
|
|
if step_size == 0:
|
|
|
|
|
raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
|
|
|
|
|
than batch_size in config.py")
|
|
|
|
|
|
|
|
|
|
# Currently, only Ascend support switch precision.
|
|
|
|
|
switch_precision(net, mstype.float16, config)
|
|
|
|
|