|
|
|
@ -633,9 +633,9 @@ class Dataset:
|
|
|
|
|
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
|
|
|
|
|
of the original dataset. If after rounding, any size equals 0, an error will occur.
|
|
|
|
|
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
|
|
|
|
|
randomize (bool): determines whether or not to split the data randomly. If true, the data
|
|
|
|
|
will be randomly split. Otherwise, each split will be created with consecutive rows
|
|
|
|
|
from the dataset.
|
|
|
|
|
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
|
|
|
|
|
If true, the data will be randomly split. Otherwise, each split will be created with
|
|
|
|
|
consecutive rows from the dataset.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
1. Dataset cannot be sharded if split is going to be called.
|
|
|
|
@ -678,7 +678,8 @@ class Dataset:
|
|
|
|
|
ds = copy.deepcopy(self)
|
|
|
|
|
if randomize:
|
|
|
|
|
# want to shuffle the same way every epoch before split
|
|
|
|
|
ds = ds.shuffle()
|
|
|
|
|
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
|
|
|
|
|
ds = ds.shuffle(10000)
|
|
|
|
|
ds.reshuffle_each_epoch = False
|
|
|
|
|
|
|
|
|
|
if rows_to_skip > 0:
|
|
|
|
@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
|
|
|
|
|
>>> new_sampler = ds.DistributedSampler(10, 2)
|
|
|
|
|
>>> data.use_sampler(new_sampler)
|
|
|
|
|
"""
|
|
|
|
|
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
|
|
|
|
raise TypeError("new_sampler is not an instance of a sampler.")
|
|
|
|
|
|
|
|
|
|
self.sampler = self.sampler.child_sampler
|
|
|
|
|
self.add_sampler(new_sampler)
|
|
|
|
|
|
|
|
|
@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
|
|
|
|
|
def is_sharded(self):
|
|
|
|
|
raise NotImplementedError("MappableDataset must implement is_sharded.")
|
|
|
|
|
|
|
|
|
|
def _get_sampler_dataset_size(self):
|
|
|
|
|
if self.sampler is not None:
|
|
|
|
|
return self.sampler.get_dataset_size()
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@check_split
|
|
|
|
|
def split(self, sizes, randomize=True):
|
|
|
|
@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
|
|
|
|
|
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
|
|
|
|
|
of the original dataset. If after rounding, any size equals 0, an error will occur.
|
|
|
|
|
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
|
|
|
|
|
randomize (bool): determines whether or not to split the data randomly. If true, the data
|
|
|
|
|
will be randomly split. Otherwise, each split will be created with consecutive rows
|
|
|
|
|
from the dataset.
|
|
|
|
|
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
|
|
|
|
|
If true, the data will be randomly split. Otherwise, each split will be created with
|
|
|
|
|
consecutive rows from the dataset.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
1. Dataset should not be sharded if split is going to be called. Instead, create a
|
|
|
|
@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
|
|
|
|
|
self.iterator = TupleIterator(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RangeDataset(MappableDataset):
|
|
|
|
|
"""
|
|
|
|
|
A source dataset that reads and parses datasets stored on disk in a range.
|
|
|
|
@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
|
|
|
|
|
else:
|
|
|
|
|
num_samples = self.num_samples
|
|
|
|
|
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
return get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
|
|
def num_classes(self):
|
|
|
|
|
"""
|
|
|
|
@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
|
|
|
|
|
num_samples = self.num_samples
|
|
|
|
|
|
|
|
|
|
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
|
|
return get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
|
|
def is_shuffled(self):
|
|
|
|
|
if self.shuffle_level is None:
|
|
|
|
@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
|
|
|
|
|
Return:
|
|
|
|
|
Number, number of batches.
|
|
|
|
|
"""
|
|
|
|
|
return self._dataset_size
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return self._dataset_size
|
|
|
|
|
|
|
|
|
|
return min(rows_from_sampler, self._dataset_size)
|
|
|
|
|
|
|
|
|
|
# manually set dataset_size as a temporary solution.
|
|
|
|
|
def set_dataset_size(self, value):
|
|
|
|
@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
|
|
|
|
|
class_indexing = self.class_indexing
|
|
|
|
|
|
|
|
|
|
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
|
|
return get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
|
|
def num_classes(self):
|
|
|
|
|
"""
|
|
|
|
@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
|
|
|
|
|
num_samples = self.num_samples
|
|
|
|
|
|
|
|
|
|
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
return get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
|
|
def is_shuffled(self):
|
|
|
|
|
if self.shuffle_level is None:
|
|
|
|
@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
|
|
|
|
|
num_samples = self.num_samples
|
|
|
|
|
|
|
|
|
|
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
|
|
return get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
|
|
def is_shuffled(self):
|
|
|
|
|
if self.shuffle_level is None:
|
|
|
|
@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
|
|
|
|
|
Return:
|
|
|
|
|
Number, number of batches.
|
|
|
|
|
"""
|
|
|
|
|
return num_samples
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return self.num_samples
|
|
|
|
|
|
|
|
|
|
return min(rows_from_sampler, self.num_samples)
|
|
|
|
|
|
|
|
|
|
def is_shuffled(self):
|
|
|
|
|
return True
|
|
|
|
@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
|
|
|
|
|
Return:
|
|
|
|
|
Number, number of batches.
|
|
|
|
|
"""
|
|
|
|
|
return self.num_samples
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return self.num_samples
|
|
|
|
|
|
|
|
|
|
return min(rows_from_sampler, self.num_samples)
|
|
|
|
|
|
|
|
|
|
def get_class_indexing(self):
|
|
|
|
|
"""
|
|
|
|
|