!5471 Remove bool type arg of script

Merge pull request !5471 from chenfei_mindspore/rm-bool-arg-of-script
pull/5471/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0f28998969

@ -38,8 +38,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="", parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file') help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
@ -67,5 +65,5 @@ if __name__ == "__main__":
raise ValueError("Load param into net fail!") raise ValueError("Load param into net fail!")
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) acc = model.eval(ds_eval, dataset_sink_mode=True)
print("============== {} ==============".format(acc)) print("============== {} ==============".format(acc))

@ -36,8 +36,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="", parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file') help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":

@ -41,8 +41,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="", parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file') help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
@ -76,5 +74,5 @@ if __name__ == "__main__":
print("============== Starting Training ==============") print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode) dataset_sink_mode=True)
print("============== End Training ==============") print("============== End Training ==============")

@ -32,7 +32,6 @@ parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target') parser.add_argument('--device_target', type=str, default=None, help='Run device target')
parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
args_opt = parser.parse_args() args_opt = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
@ -51,9 +50,8 @@ if __name__ == '__main__':
# define fusion network # define fusion network
network = mobilenetV2(num_classes=config_device_target.num_classes) network = mobilenetV2(num_classes=config_device_target.num_classes)
if args_opt.quantization_aware: # convert fusion network to quantization aware network
# convert fusion network to quantization aware network network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# define network loss # define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')

@ -50,5 +50,4 @@ python ${BASEPATH}/../eval.py \
--device_target=$1 \ --device_target=$1 \
--dataset_path=$2 \ --dataset_path=$2 \
--checkpoint_path=$3 \ --checkpoint_path=$3 \
--quantization_aware=True \
&> infer.log & # dataset val folder path &> infer.log & # dataset val folder path

Loading…
Cancel
Save