@@ -22,7 +22,6 @@ import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 import mindspore.dataset.transforms.vision.py_transforms as P
-from src.config import config_ascend
 
 
 def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
@@ -42,7 +41,7 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
         rank_size = int(os.getenv("RANK_SIZE"))
         rank_id = int(os.getenv("RANK_ID"))
         columns_list = ['image', 'label']
-        if config_ascend.data_load_mode == "mindrecord":
+        if config.data_load_mode == "mindrecord":
             load_func = partial(de.MindDataset, dataset_path, columns_list)
         else:
             load_func = partial(de.ImageFolderDatasetV2, dataset_path)
@@ -54,6 +53,13 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
                                num_shards=rank_size, shard_id=rank_id)
         else:
             ds = load_func(num_parallel_workers=8, shuffle=False)
+    elif device_target == "GPU":
+        if do_train:
+            from mindspore.communication.management import get_rank, get_group_size
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+                                         num_shards=get_group_size(), shard_id=get_rank())
+        else:
+            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
     else:
         raise ValueError("Unsupport device_target.")
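For reviewers, a minimal usage sketch of the patched create_dataset (not part of this change): it assumes the patched file is importable as src.dataset, the dataset path below is a placeholder, and the config object passed in (here config_ascend) exposes the data_load_mode attribute that the function now reads from its config parameter instead of the removed module-level import.

# Usage sketch only -- not part of the patch. The dataset path is a placeholder,
# and `src.dataset` is the assumed module path of the patched file.
import os

from src.config import config_ascend      # the config is now passed in explicitly
from src.dataset import create_dataset

# The Ascend branch reads the shard layout from these environment variables,
# which the distributed launch scripts normally export.
os.environ.setdefault("RANK_SIZE", "1")
os.environ.setdefault("RANK_ID", "0")

train_ds = create_dataset("/path/to/imagenet/train",   # placeholder path
                          do_train=True,
                          config=config_ascend,
                          device_target="Ascend",
                          batch_size=32)
print(train_ds.get_dataset_size())          # number of batches per epoch

The new GPU branch is called the same way with device_target="GPU"; for distributed GPU training it additionally expects mindspore.communication.management.init() to have been called before dataset creation so that get_rank() and get_group_size() return valid values.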