|
|
|
@ -501,6 +501,35 @@ def test_cifar_exception_file_path():
|
|
|
|
|
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cifar10_pk_sampler_get_dataset_size():
|
|
|
|
|
"""
|
|
|
|
|
Test Cifar10Dataset with PKSampler and get_dataset_size
|
|
|
|
|
"""
|
|
|
|
|
sampler = ds.PKSampler(3)
|
|
|
|
|
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
|
|
|
|
|
num_iter = 0
|
|
|
|
|
ds_sz = data.get_dataset_size()
|
|
|
|
|
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
|
|
|
|
num_iter += 1
|
|
|
|
|
|
|
|
|
|
assert ds_sz == num_iter == 30
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cifar10_with_chained_sampler_get_dataset_size():
|
|
|
|
|
"""
|
|
|
|
|
Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size
|
|
|
|
|
"""
|
|
|
|
|
sampler = ds.SequentialSampler(start_index=0, num_samples=5)
|
|
|
|
|
child_sampler = ds.PKSampler(4)
|
|
|
|
|
sampler.add_child(child_sampler)
|
|
|
|
|
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
|
|
|
|
|
num_iter = 0
|
|
|
|
|
ds_sz = data.get_dataset_size()
|
|
|
|
|
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
|
|
|
|
num_iter += 1
|
|
|
|
|
assert ds_sz == num_iter == 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
test_cifar10_content_check()
|
|
|
|
|
test_cifar10_basic()
|
|
|
|
@ -517,3 +546,6 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
test_cifar_usage()
|
|
|
|
|
test_cifar_exception_file_path()
|
|
|
|
|
|
|
|
|
|
test_cifar10_with_chained_sampler_get_dataset_size()
|
|
|
|
|
test_cifar10_pk_sampler_get_dataset_size()
|
|
|
|
|