|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import argparse
|
|
|
|
|
import ast
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
@ -35,13 +36,13 @@ from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
|
|
|
|
parser = argparse.ArgumentParser(description='Image classification')
|
|
|
|
|
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
|
|
|
|
parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
|
|
|
|
|
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
|
|
|
|
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
|
|
|
|
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
|
|
|
|
|
|
|
|
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
|
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
|
|
|
|
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
|
|
|
|
parser.add_argument('--parameter_server', type=bool, default=False, help='Run parameter server train')
|
|
|
|
|
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
random.seed(1)
|
|
|
|
|