@@ -49,10 +49,10 @@ de.config.set_seed(1)
parser = argparse . ArgumentParser ( description = ' Image classification ' )
parser . add_argument ( ' --dataset_path ' , type = str , default = None , help = ' Dataset path ' )
parser . add_argument ( ' --pre_trained ' , type = str , default = None , help = ' Pretrained checkpoint path ' )
parser . add_argument ( ' --device_targe ' , type = str , default = None , help = ' run device_targe ' )
parser . add_argument ( ' --device_target ' , type = str , default = None , help = ' run device_target ' )
args_opt = parser . parse_args ( )
if args_opt . device_targe == " Ascend " :
if args_opt . device_target == " Ascend " :
device_id = int ( os . getenv ( ' DEVICE_ID ' ) )
rank_id = int ( os . getenv ( ' RANK_ID ' ) )
rank_size = int ( os . getenv ( ' RANK_SIZE ' ) )
@@ -61,7 +61,7 @@ if args_opt.device_targe == "Ascend":
context . set_context ( mode = context . GRAPH_MODE ,
device_target = " Ascend " ,
device_id = device_id , save_graphs = False )
elif args_opt . device_targe == " GPU " :
elif args_opt . device_target == " GPU " :
context . set_context ( mode = context . GRAPH_MODE ,
device_target = " GPU " ,
save_graphs = False )
@@ -161,13 +161,13 @@ class Monitor(Callback):
if __name__ == ' __main__ ' :
if args_opt . device_targe == " GPU " :
if args_opt . device_target == " GPU " :
# train on gpu
print ( " train args: " , args_opt )
print ( " cfg: " , config_gpu )
# define network
net = mobilenet_v2 ( num_classes = config_gpu . num_classes , device_targe = " GPU " )
net = mobilenet_v2 ( num_classes = config_gpu . num_classes , device_target = " GPU " )
# define loss
if config_gpu . label_smooth > 0 :
loss = CrossEntropyWithLabelSmooth ( smooth_factor = config_gpu . label_smooth ,
@@ -179,7 +179,7 @@ if __name__ == '__main__':
dataset = create_dataset ( dataset_path = args_opt . dataset_path ,
do_train = True ,
config = config_gpu ,
device_targe = args_opt . device_targe ,
device_target = args_opt . device_target ,
repeat_num = 1 ,
batch_size = config_gpu . batch_size )
step_size = dataset . get_dataset_size ( )
@@ -216,7 +216,7 @@ if __name__ == '__main__':
# begin train
model . train ( epoch_size , dataset , callbacks = cb )
print ( " ============== End Training ============== " )
elif args_opt . device_targe == " Ascend " :
elif args_opt . device_target == " Ascend " :
# train on ascend
print ( " train args: " , args_opt , " \n cfg: " , config_ascend ,
" \n parallel args: rank_id {} , device_id {} , rank_size {} " . format ( rank_id , device_id , rank_size ) )
@@ -228,7 +228,7 @@ if __name__ == '__main__':
init ( )
epoch_size = config_ascend . epoch_size
net = mobilenet_v2 ( num_classes = config_ascend . num_classes , device_targe = " Ascend " )
net = mobilenet_v2 ( num_classes = config_ascend . num_classes , device_target = " Ascend " )
net . to_float ( mstype . float16 )
for _ , cell in net . cells_and_names ( ) :
if isinstance ( cell , nn . Dense ) :
@@ -242,7 +242,7 @@ if __name__ == '__main__':
dataset = create_dataset ( dataset_path = args_opt . dataset_path ,
do_train = True ,
config = config_ascend ,
device_targe = args_opt . device_targe ,
device_target = args_opt . device_target ,
repeat_num = 1 ,
batch_size = config_ascend . batch_size )
step_size = dataset . get_dataset_size ( )
@@ -276,4 +276,4 @@ if __name__ == '__main__':
cb + = [ ckpt_cb ]
model . train ( epoch_size , dataset , callbacks = cb )
else :
raise ValueError ( " Unsupported device_targe ." )
raise ValueError ( " Unsupported device_target ." )