|
|
@ -85,10 +85,10 @@ def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
|
|
|
|
device_num, rank_id = _get_rank_info()
|
|
|
|
device_num, rank_id = _get_rank_info()
|
|
|
|
|
|
|
|
|
|
|
|
if device_num == 1:
|
|
|
|
if device_num == 1:
|
|
|
|
data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
|
|
|
|
|
|
|
|
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
|
|
|
|
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
|
|
|
|
image_size = imagenet_cfg.image_height
|
|
|
|
image_size = imagenet_cfg.image_height
|
|
|
|