diff --git a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh index e4d41ac9a2..c7a3d00863 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh +++ b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh @@ -75,15 +75,15 @@ run_gpu() python ${BASEPATH}/../train.py \ --dataset_path=$4 \ --device_target=$1 \ - --quantization_aware=True \ - &> ../train.log & # dataset train folder + --pre_trained=$5 \ + --quantization_aware=True &> ../train.log & # dataset train folder } -if [ $# -gt 6 ] || [ $# -lt 4 ] +if [ $# -gt 6 ] || [ $# -lt 5 ] then echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ + Ascend: sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + GPU: sh run_train_quant.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ " exit 1 fi diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/dataset.py b/model_zoo/official/cv/mobilenetv2_quant/src/dataset.py index 105a5e1396..ce9f2d8faf 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/src/dataset.py +++ b/model_zoo/official/cv/mobilenetv2_quant/src/dataset.py @@ -22,7 +22,6 @@ import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.vision.py_transforms as P -from src.config import config_ascend def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): @@ -42,7 +41,7 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, rank_size = int(os.getenv("RANK_SIZE")) rank_id = int(os.getenv("RANK_ID")) columns_list = ['image', 'label'] - if config_ascend.data_load_mode == "mindrecord": + if config.data_load_mode == "mindrecord": load_func = partial(de.MindDataset, dataset_path, columns_list) else: load_func = partial(de.ImageFolderDatasetV2, dataset_path) @@ -54,6 +53,13 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, num_shards=rank_size, shard_id=rank_id) else: ds = load_func(num_parallel_workers=8, shuffle=False) + elif device_target == "GPU": + if do_train: + from mindspore.communication.management import get_rank, get_group_size + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=get_group_size(), shard_id=get_rank()) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) else: raise ValueError("Unsupport device_target.") diff --git a/model_zoo/official/cv/mobilenetv2_quant/train.py b/model_zoo/official/cv/mobilenetv2_quant/train.py index 2253fbb5a0..bd8051cccc 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/train.py +++ b/model_zoo/official/cv/mobilenetv2_quant/train.py @@ -56,7 +56,7 @@ if args_opt.device_target == "Ascend": context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) -elif args_opt.platform == "GPU": +elif args_opt.device_target == "GPU": init("nccl") context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, @@ -205,5 +205,5 @@ def train_on_gpu(): if __name__ == '__main__': if args_opt.device_target == "Ascend": train_on_ascend() - elif args_opt.platform == "GPU": + elif args_opt.device_target == "GPU": train_on_gpu()