|
|
|
@ -24,6 +24,7 @@ from mindspore.model_zoo.mobilenet import mobilenet_v2
|
|
|
|
|
from mindspore.train.model import Model
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Image classification')
|
|
|
|
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
|
|
|
@ -39,7 +40,8 @@ context.set_context(enable_mem_reuse=True)
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
|
|
|
|
|
net = mobilenet_v2()
|
|
|
|
|
net = mobilenet_v2(num_classes=config.num_classes)
|
|
|
|
|
net.to_float(mstype.float16)
|
|
|
|
|
|
|
|
|
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
|
|
|
|
|
step_size = dataset.get_dataset_size()
|
|
|
|
|