fixed bug for split, RandomSampler and some other cleanup

add another test case

typo

merge conflict

another PR changed testing behavior, updated test cases in this commit

added input check for use_sampler

addressed code review comments

fixed pylint, not related to my changes

fixed edge case of rounding in getting split sizes

fix pylint
pull/1638/head
Peilin Wang 5 years ago
parent 6420f7248f
commit 5469be2a97

@ -609,7 +609,20 @@ class Dataset:
absolute_sizes.append(absolute_size)
absolute_sizes_sum = sum(absolute_sizes)
if absolute_sizes_sum != dataset_size:
# if we still need more rows, give them to the first split.
# if we have too many rows, remove the extras from the first split that has
# enough rows.
size_difference = dataset_size - absolute_sizes_sum
if size_difference > 0:
absolute_sizes[0] += size_difference
else:
for i, _ in enumerate(absolute_sizes):
if absolute_sizes[i] + size_difference > 0:
absolute_sizes[i] += size_difference
break
if sum(absolute_sizes) != dataset_size:
raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
.format(absolute_sizes_sum, dataset_size))
@ -629,10 +642,15 @@ class Dataset:
provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
respectively. If the sum of all sizes does not equal the original dataset size, an
error will occur.
If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
and must sum to 1, otherwise an error will occur. The dataset will be split into n
Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
original dataset.
If after rounding:
-Any size equals 0, an error will occur.
-The sum of split sizes < K, the difference will be added to the first split.
-The sum of split sizes > K, the difference will be removed from the first large
enough split such that it will have at least 1 row after removing the difference.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
@ -1212,7 +1230,7 @@ class MappableDataset(SourceDataset):
>>> data.use_sampler(new_sampler)
"""
if new_sampler is None:
raise TypeError("Input sampler could not be None.")
raise TypeError("Input sampler can not be None.")
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("Input sampler is not an instance of a sampler.")
@ -1247,10 +1265,15 @@ class MappableDataset(SourceDataset):
provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
respectively. If the sum of all sizes does not equal the original dataset size, an
error will occur.
If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
and must sum to 1, otherwise an error will occur. The dataset will be split into n
Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
original dataset.
If after rounding:
-Any size equals 0, an error will occur.
-The sum of split sizes < K, the difference will be added to the first split.
-The sum of split sizes > K, the difference will be removed from the first large
enough split such that it will have at least 1 row after removing the difference.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.

@ -554,6 +554,43 @@ def test_mappable_multi_split():
assert s2_output == [2]
def test_rounding():
    """Test split() with float ratios whose rounded sizes do not sum to the dataset size.

    Exercises both correction paths of the split-size rounding fix:
    - under rounding: the rounded sizes sum to less than the dataset size, so the
      difference is added to the first split.
    - over rounding: the rounded sizes sum to more than the dataset size, so the
      difference is removed from the first split large enough to absorb it.

    Uses randomize=False so the row-to-split assignment is deterministic.
    """
    def _collect(split):
        # Materialize one split into its list of row ids via manifest_map.
        return [manifest_map[(item["image"].shape[0], item["label"].item())]
                for item in split.create_dict_iterator()]

    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # under rounding: with a 5-row dataset (rows 0-4 per the expected outputs below),
    # round(0.5 * 5) == 2 for both splits, 2 + 2 = 4 < 5, so the first split gets the
    # leftover row.
    s1, s2 = d.split([0.5, 0.5], randomize=False)
    assert _collect(s1) == [0, 1, 2]
    assert _collect(s2) == [3, 4]

    # over rounding: the rounded sizes sum to more than 5, so the excess is removed
    # from the first split that still has at least 1 row afterwards.
    s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False)
    assert _collect(s1) == [0]
    assert _collect(s2) == [1, 2]
    assert _collect(s3) == [3, 4]
if __name__ == '__main__':
test_unmappable_invalid_input()
test_unmappable_split()
@ -569,3 +606,4 @@ if __name__ == '__main__':
test_mappable_sharding()
test_mappable_get_dataset_size()
test_mappable_multi_split()
test_rounding()

Loading…
Cancel
Save