|
|
|
@@ -52,15 +52,15 @@ class ParameterReduce(nn.Cell):
|
|
|
|
|
def parse_args(cloud_args=None):
|
|
|
|
|
"""parse_args"""
|
|
|
|
|
parser = argparse.ArgumentParser('mindspore classification test')
|
|
|
|
|
parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
|
|
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
|
|
|
|
help='device where the code will be implemented. (Default: Ascend)')
|
|
|
|
|
# dataset related
|
|
|
|
|
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="imagenet2012")
|
|
|
|
|
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
|
|
|
|
|
parser.add_argument('--data_path', type=str, default='', help='eval data dir')
|
|
|
|
|
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
|
|
|
|
|
# network related
|
|
|
|
|
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
|
|
|
|
|
parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. '
|
|
|
|
|
parser.add_argument('--pre_trained', default='', type=str, help='fully path of pretrained model to load. '
|
|
|
|
|
'If it is a direction, it will test all ckpt')
|
|
|
|
|
|
|
|
|
|
# logging related
|
|
|
|
@@ -68,9 +68,6 @@ def parse_args(cloud_args=None):
|
|
|
|
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
|
|
|
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
|
|
|
|
|
|
|
|
|
# roma obs
|
|
|
|
|
parser.add_argument('--train_url', type=str, default="", help='train url')
|
|
|
|
|
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
args_opt = merge_args(args_opt, cloud_args)
|
|
|
|
|
|
|
|
|
@@ -82,6 +79,8 @@ def parse_args(cloud_args=None):
|
|
|
|
|
args_opt.image_size = cfg.image_size
|
|
|
|
|
args_opt.num_classes = cfg.num_classes
|
|
|
|
|
args_opt.per_batch_size = cfg.batch_size
|
|
|
|
|
args_opt.momentum = cfg.momentum
|
|
|
|
|
args_opt.weight_decay = cfg.weight_decay
|
|
|
|
|
args_opt.buffer_size = cfg.buffer_size
|
|
|
|
|
args_opt.pad_mode = cfg.pad_mode
|
|
|
|
|
args_opt.padding = cfg.padding
|
|
|
|
@@ -130,23 +129,23 @@ def test(cloud_args=None):
|
|
|
|
|
args.logger.save_args(args)
|
|
|
|
|
|
|
|
|
|
if args.dataset == "cifar10":
|
|
|
|
|
net = vgg16(num_classes=args.num_classes)
|
|
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
|
|
|
|
|
net = vgg16(num_classes=args.num_classes, args=args)
|
|
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum,
|
|
|
|
|
weight_decay=args.weight_decay)
|
|
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
|
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
|
|
|
|
|
|
|
|
|
param_dict = load_checkpoint(args.checkpoint_path)
|
|
|
|
|
param_dict = load_checkpoint(args.pre_trained)
|
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
|
net.set_train(False)
|
|
|
|
|
dataset = vgg_create_dataset(args.data_path, 1, False)
|
|
|
|
|
dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False)
|
|
|
|
|
res = model.eval(dataset)
|
|
|
|
|
print("result: ", res)
|
|
|
|
|
else:
|
|
|
|
|
# network
|
|
|
|
|
args.logger.important_info('start create network')
|
|
|
|
|
if os.path.isdir(args.pretrained):
|
|
|
|
|
models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
|
|
|
|
|
if os.path.isdir(args.pre_trained):
|
|
|
|
|
models = list(glob.glob(os.path.join(args.pre_trained, '*.ckpt')))
|
|
|
|
|
print(models)
|
|
|
|
|
if args.graph_ckpt:
|
|
|
|
|
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
|
|
|
|
@@ -154,14 +153,10 @@ def test(cloud_args=None):
|
|
|
|
|
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
|
|
|
|
|
args.models = sorted(models, key=f)
|
|
|
|
|
else:
|
|
|
|
|
args.models = [args.pretrained,]
|
|
|
|
|
args.models = [args.pre_trained,]
|
|
|
|
|
|
|
|
|
|
for model in args.models:
|
|
|
|
|
if args.dataset == "cifar10":
|
|
|
|
|
dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False)
|
|
|
|
|
else:
|
|
|
|
|
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size)
|
|
|
|
|
|
|
|
|
|
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size)
|
|
|
|
|
eval_dataloader = dataset.create_tuple_iterator()
|
|
|
|
|
network = vgg16(args.num_classes, args, phase="test")
|
|
|
|
|
|
|
|
|
|