|
|
|
@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|
|
|
|
|
|
|
|
|
device_num = int(os.getenv("RANK_SIZE"))
|
|
|
|
|
rank_id = int(os.getenv("RANK_ID"))
|
|
|
|
|
if device_num == 1:
|
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
|
|
|
|
if do_train:
|
|
|
|
|
if device_num == 1:
|
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
|
|
|
|
else:
|
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
|
else:
|
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
|
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
|
|
|
|
|
|
image_size = 224
|
|
|
|
|