|
|
|
@ -177,13 +177,23 @@ def test_subset_sampler():
|
|
|
|
|
def pipeline():
|
|
|
|
|
sampler = ds.SubsetSampler(indices, num_samples)
|
|
|
|
|
data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler)
|
|
|
|
|
data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples)
|
|
|
|
|
dataset_size = data.get_dataset_size()
|
|
|
|
|
return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
|
|
|
|
|
dataset_size2 = data.get_dataset_size()
|
|
|
|
|
res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
|
|
|
|
|
res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2
|
|
|
|
|
return res1, res2
|
|
|
|
|
|
|
|
|
|
if exception_msg is None:
|
|
|
|
|
res, size = pipeline()
|
|
|
|
|
res, res2 = pipeline()
|
|
|
|
|
res, size = res
|
|
|
|
|
res2, size2 = res2
|
|
|
|
|
if not isinstance(indices, list):
|
|
|
|
|
indices = list(indices)
|
|
|
|
|
assert indices[:num_samples] == res
|
|
|
|
|
assert len(indices[:num_samples]) == size
|
|
|
|
|
assert indices[:num_samples] == res2
|
|
|
|
|
assert len(indices[:num_samples]) == size2
|
|
|
|
|
else:
|
|
|
|
|
with pytest.raises(Exception) as error_info:
|
|
|
|
|
pipeline()
|
|
|
|
@ -205,6 +215,8 @@ def test_subset_sampler():
|
|
|
|
|
test_config([0, 9, 3, 2], num_samples=2)
|
|
|
|
|
test_config([0, 9, 3, 2], num_samples=5)
|
|
|
|
|
|
|
|
|
|
test_config(np.array([1, 2, 3]))
|
|
|
|
|
|
|
|
|
|
test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]")
|
|
|
|
|
test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]")
|
|
|
|
|
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
|
|
|
|
@ -212,6 +224,9 @@ def test_subset_sampler():
|
|
|
|
|
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
|
|
|
|
|
test_config([0, 9, 3, 2], num_samples=-1,
|
|
|
|
|
exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)")
|
|
|
|
|
test_config(np.array([[1], [5]]), num_samples=10,
|
|
|
|
|
exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1],"
|
|
|
|
|
" type: <class 'numpy.ndarray'>.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sampler_chain():
|
|
|
|
@ -291,8 +306,8 @@ def test_sampler_list():
|
|
|
|
|
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
|
|
|
|
|
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
|
|
|
|
|
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
|
|
|
|
|
bad_pipeline(sampler=np.array([1, 2]),
|
|
|
|
|
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")
|
|
|
|
|
bad_pipeline(sampler=np.array([[1, 2]]),
|
|
|
|
|
msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|