diff --git a/model_zoo/official/cv/alexnet/eval.py b/model_zoo/official/cv/alexnet/eval.py index 6a091aedd8..b8d7a87c36 100644 --- a/model_zoo/official/cv/alexnet/eval.py +++ b/model_zoo/official/cv/alexnet/eval.py @@ -18,6 +18,7 @@ eval alexnet according to model file: python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt """ +import ast import argparse from src.config import alexnet_cfg as cfg from src.dataset import create_dataset_cifar10 @@ -36,7 +37,8 @@ if __name__ == "__main__": parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=True, help='dataset_sink_mode is False or True') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py index 4512244b92..37d7ca1b60 100644 --- a/model_zoo/official/cv/alexnet/train.py +++ b/model_zoo/official/cv/alexnet/train.py @@ -18,6 +18,7 @@ train alexnet and get network model files(.ckpt) : python train.py --data_path /YourDataPath """ +import ast import argparse from src.config import alexnet_cfg as cfg from src.dataset import create_dataset_cifar10 @@ -38,7 +39,8 @@ if __name__ == "__main__": parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=True, help='dataset_sink_mode is False or True') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) diff --git a/model_zoo/official/cv/lenet/eval.py b/model_zoo/official/cv/lenet/eval.py index bcd5503c39..4083a06400 100644 --- a/model_zoo/official/cv/lenet/eval.py +++ b/model_zoo/official/cv/lenet/eval.py @@ -19,6 +19,7 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt """ import os +import ast import argparse import mindspore.nn as nn from mindspore import context @@ -37,7 +38,8 @@ if __name__ == "__main__": help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=False, help='dataset_sink_mode is False or True') args = parser.parse_args() diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index 2c45c5b327..7cd379134a 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -19,6 +19,7 @@ python train.py --data_path /YourDataPath """ import os +import ast import argparse from src.config import mnist_cfg as cfg from src.dataset import create_dataset @@ -38,7 +39,8 @@ if __name__ == "__main__": help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ path where the trained ckpt file') - parser.add_argument('--dataset_sink_mode', type=bool, default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, default=True, + help='dataset_sink_mode is False or True') args = parser.parse_args()