|
|
|
@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
|
|
|
|
|
>>> new_sampler = ds.DistributedSampler(10, 2)
|
|
|
|
|
>>> data.use_sampler(new_sampler)
|
|
|
|
|
"""
|
|
|
|
|
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
|
|
|
|
raise TypeError("new_sampler is not an instance of a sampler.")
|
|
|
|
|
if new_sampler is None:
|
|
|
|
|
raise TypeError("Input sampler could not be None.")
|
|
|
|
|
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
|
|
|
|
raise TypeError("Input sampler is not an instance of a sampler.")
|
|
|
|
|
|
|
|
|
|
self.sampler = self.sampler.child_sampler
|
|
|
|
|
self.add_sampler(new_sampler)
|
|
|
|
@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
|
|
|
|
|
Return:
|
|
|
|
|
Number, number of batches.
|
|
|
|
|
"""
|
|
|
|
|
if self.num_samples is None:
|
|
|
|
|
num_samples = 0
|
|
|
|
|
else:
|
|
|
|
|
num_samples = self.num_samples
|
|
|
|
|
|
|
|
|
|
if self.class_indexing is None:
|
|
|
|
|
class_indexing = dict()
|
|
|
|
|
else:
|
|
|
|
|
class_indexing = self.class_indexing
|
|
|
|
|
|
|
|
|
|
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
return self.num_samples
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
|
|
return min(rows_from_sampler, self.num_samples)
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
|
|
def get_class_indexing(self):
|
|
|
|
|
"""
|
|
|
|
|