|
|
|
@ -57,6 +57,24 @@ def test_imagefolder_numsamples():
|
|
|
|
|
logger.info("Number of data in data1: {}".format(num_iter))
|
|
|
|
|
assert num_iter == 10
|
|
|
|
|
|
|
|
|
|
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
|
|
|
|
|
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
|
|
|
|
|
|
|
|
|
num_iter = 0
|
|
|
|
|
for item in data1.create_dict_iterator():
|
|
|
|
|
num_iter += 1
|
|
|
|
|
|
|
|
|
|
assert num_iter == 3
|
|
|
|
|
|
|
|
|
|
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
|
|
|
|
|
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
|
|
|
|
|
|
|
|
|
num_iter = 0
|
|
|
|
|
for item in data1.create_dict_iterator():
|
|
|
|
|
num_iter += 1
|
|
|
|
|
|
|
|
|
|
assert num_iter == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_imagefolder_numshards():
|
|
|
|
|
logger.info("Test Case numShards")
|
|
|
|
|