|
|
|
@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain():
|
|
|
|
|
# Use 1 statement to add child sampler
|
|
|
|
|
np_data = [1, 2, 3, 4]
|
|
|
|
|
sampler = ds.SequentialSampler(start_index=1, num_samples=2)
|
|
|
|
|
sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
|
|
|
|
|
sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
|
|
|
|
|
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
|
|
|
|
|
|
|
|
|
|
# Verify dataset size
|
|
|
|
|
data1_size = data1.get_dataset_size()
|
|
|
|
|
logger.info("dataset size is: {}".format(data1_size))
|
|
|
|
|
assert data1_size == 4
|
|
|
|
|
assert data1_size == 1
|
|
|
|
|
|
|
|
|
|
# Verify number of rows
|
|
|
|
|
assert sum([1 for _ in data1]) == 4
|
|
|
|
|
assert sum([1 for _ in data1]) == 1
|
|
|
|
|
|
|
|
|
|
# Verify dataset contents
|
|
|
|
|
res = []
|
|
|
|
|