@ -633,9 +633,9 @@ class Dataset:
Datasets of size f1 * K , f2 * K , … , fn * K ( rounded to nearest integer ) where K is the size
of the original dataset . If after rounding , any size equals 0 , an error will occur .
All floats must be between 0 and 1 and must sum to 1 , otherwise an error will occur .
randomize ( bool ): determines whether or not to split the data randomly . If true , the data
will be randomly split . Otherwise , each split will be created with consecutive rows
from the dataset .
randomize ( bool , optional ): determines whether or not to split the data randomly ( default = True ) .
If true , the data will be randomly split . Otherwise , each split will be created with
consecutive rows from the dataset .
Note :
1. Dataset cannot be sharded if split is going to be called .
@ -678,7 +678,8 @@ class Dataset:
ds = copy . deepcopy ( self )
if randomize :
# want to shuffle the same way every epoch before split
ds = ds . shuffle ( )
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
ds = ds . shuffle ( 10000 )
ds . reshuffle_each_epoch = False
if rows_to_skip > 0 :
@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
>> > new_sampler = ds . DistributedSampler ( 10 , 2 )
>> > data . use_sampler ( new_sampler )
"""
if new_sampler is not None and not isinstance ( new_sampler , ( samplers . BuiltinSampler , samplers . Sampler ) ) :
raise TypeError ( " new_sampler is not an instance of a sampler. " )
self . sampler = self . sampler . child_sampler
self . add_sampler ( new_sampler )
@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
def is_sharded ( self ) :
raise NotImplementedError ( " MappableDataset must implement is_sharded. " )
def _get_sampler_dataset_size ( self ) :
if self . sampler is not None :
return self . sampler . get_dataset_size ( )
return None
@check_split
def split ( self , sizes , randomize = True ) :
@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
Datasets of size f1 * K , f2 * K , … , fn * K ( rounded to nearest integer ) where K is the size
of the original dataset . If after rounding , any size equals 0 , an error will occur .
All floats must be between 0 and 1 and must sum to 1 , otherwise an error will occur .
randomize ( bool ): determines whether or not to split the data randomly . If true , the data
will be randomly split . Otherwise , each split will be created with consecutive rows
from the dataset .
randomize ( bool , optional ): determines whether or not to split the data randomly ( default = True ) .
If true , the data will be randomly split . Otherwise , each split will be created with
consecutive rows from the dataset .
Note :
1. Dataset should not be sharded if split is going to be called . Instead , create a
@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
self . iterator = TupleIterator ( self )
class RangeDataset ( MappableDataset ) :
"""
A source dataset that reads and parses datasets stored on disk in a range .
@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
else :
num_samples = self . num_samples
num_rows = ImageFolderOp . get_num_rows_and_classes ( self . dataset_dir , num_samples ) [ 0 ]
rows_per_shard = get_num_rows ( num_rows , self . num_shards )
rows_from_sampler = self . _get_sampler_dataset_size ( )
return get_num_rows ( num_rows , self . num_shards )
if rows_from_sampler is None :
return rows_per_shard
return min ( rows_from_sampler , rows_per_shard )
def num_classes ( self ) :
"""
@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
num_samples = self . num_samples
num_rows = MnistOp . get_num_rows ( self . dataset_dir , num_samples )
rows_per_shard = get_num_rows ( num_rows , self . num_shards )
rows_from_sampler = self . _get_sampler_dataset_size ( )
if rows_from_sampler is None :
return rows_per_shard
return get_num_rows ( num_rows , self . num_shards )
return min ( rows_from_sampler , rows_per_shard )
def is_shuffled ( self ) :
if self . shuffle_level is None :
@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
Return :
Number , number of batches .
"""
return self . _dataset_size
rows_from_sampler = self . _get_sampler_dataset_size ( )
if rows_from_sampler is None :
return self . _dataset_size
return min ( rows_from_sampler , self . _dataset_size )
# manually set dataset_size as a temporary solution.
def set_dataset_size ( self , value ) :
@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
class_indexing = self . class_indexing
num_rows = ManifestOp . get_num_rows_and_classes ( self . dataset_file , num_samples , class_indexing , self . usage ) [ 0 ]
rows_per_shard = get_num_rows ( num_rows , self . num_shards )
rows_from_sampler = self . _get_sampler_dataset_size ( )
if rows_from_sampler is None :
return rows_per_shard
return get_num_rows ( num_rows , self . num_shards )
return min ( rows_from_sampler , rows_per_shard )
def num_classes ( self ) :
"""
@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
num_samples = self . num_samples
num_rows = CifarOp . get_num_rows ( self . dataset_dir , num_samples , True )
rows_per_shard = get_num_rows ( num_rows , self . num_shards )
rows_from_sampler = self . _get_sampler_dataset_size ( )
return get_num_rows ( num_rows , self . num_shards )
if rows_from_sampler is None :
return rows_per_shard
return min ( rows_from_sampler , rows_per_shard )
def is_shuffled ( self ) :
if self . shuffle_level is None :
@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
num_samples = self . num_samples
num_rows = CifarOp . get_num_rows ( self . dataset_dir , num_samples , False )
rows_per_shard = get_num_rows ( num_rows , self . num_shards )
rows_from_sampler = self . _get_sampler_dataset_size ( )
if rows_from_sampler is None :
return rows_per_shard
return get_num_rows ( num_rows , self . num_shards )
return min ( rows_from_sampler , rows_per_shard )
def is_shuffled ( self ) :
if self . shuffle_level is None :
@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
Return :
Number , number of batches .
"""
return num_samples
rows_from_sampler = self . _get_sampler_dataset_size ( )
if rows_from_sampler is None :
return self . num_samples
return min ( rows_from_sampler , self . num_samples )
def is_shuffled ( self ) :
return True
@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
Return :
Number , number of batches .
"""
return self . num_samples
rows_from_sampler = self . _get_sampler_dataset_size ( )
if rows_from_sampler is None :
return self . num_samples
return min ( rows_from_sampler , self . num_samples )
def get_class_indexing ( self ) :
"""