modify for resnet readme

pull/5313/head
qujianwei 5 years ago
parent 8c2e4304d5
commit 2c6926fecc

File diff suppressed because it is too large Load Diff

@ -16,6 +16,7 @@
import os import os
import random import random
import argparse import argparse
import ast
import numpy as np import numpy as np
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
@ -35,13 +36,13 @@ from src.lr_generator import get_lr, warmup_cosine_annealing_lr
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012') parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--parameter_server', type=bool, default=False, help='Run parameter server train') parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
args_opt = parser.parse_args() args_opt = parser.parse_args()
random.seed(1) random.seed(1)

Loading…
Cancel
Save