@@ -3780,6 +3780,18 @@ class GeneratorDataset(MappableDataset):
             self.schema = schema
             if not isinstance(schema, Schema):
                 self.schema = Schema(schema)
+        # Move get dataset_size by len from parse to here, because self.source will
+        # lose attribution of '__len__' after deepcopy.
+        self.dataset_size = None
+        if hasattr(self.source, "__len__"):
+            if not self.num_shards:
+                self.dataset_size = len(self.source)
+            else:
+                self.dataset_size = math.ceil(len(self.source) / self.num_shards)
+
+            rows_from_sampler = self._get_sampler_dataset_size()
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
 
     def __deepcopy__(self, memodict):
         if id(self) in memodict:
@@ -3838,24 +3850,16 @@ class GeneratorDataset(MappableDataset):
         return self.sampler.is_sharded()
 
     def parse(self, children=None):
-        dataset_size = -1
-        if hasattr(self.source, "__len__"):
-            if not self.num_shards:
-                dataset_size = len(self.source)
-            else:
-                dataset_size = math.ceil(len(self.source) / self.num_shards)
-
-            rows_from_sampler = self._get_sampler_dataset_size()
-            if rows_from_sampler is not None and rows_from_sampler < dataset_size:
-                dataset_size = rows_from_sampler
+        if self.dataset_size is None:
+            self.dataset_size = -1
         if self.schema is None:
             return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize(
-                dataset_size) \
+                self.dataset_size) \
                 .SetNumWorkers(self.num_parallel_workers)
         schema = self.schema
         if isinstance(schema, Schema):
             schema = self.schema.cpp_schema
-        return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(dataset_size).SetNumWorkers(
+        return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers(
             self.num_parallel_workers)
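
A minimal standalone sketch (not MindSpore code) of the size arithmetic this diff caches in __init__: the per-shard row count is math.ceil(len(source) / num_shards), optionally clamped by the number of rows the sampler will yield, and it must be computed while source still exposes __len__, before the dataset object is deep-copied. The names _SizedSource and cached_dataset_size below are hypothetical helpers used only for illustration.

    import math

    class _SizedSource:
        """Hypothetical source object whose length is known up front."""
        def __init__(self, n):
            self._n = n

        def __len__(self):
            return self._n

    def cached_dataset_size(source, num_shards=None, rows_from_sampler=None):
        """Mirror of the cached-size logic: compute once, while __len__ is still available."""
        if not hasattr(source, "__len__"):
            return None  # size unknown; parse() later falls back to -1
        if not num_shards:
            size = len(source)
        else:
            # Rows per shard, rounded up so no shard is reported smaller than it is.
            size = math.ceil(len(source) / num_shards)
        if rows_from_sampler is not None and rows_from_sampler < size:
            size = rows_from_sampler  # a sampler can only shrink the effective size
        return size

    size = cached_dataset_size(_SizedSource(10), num_shards=4)
    print(size)  # 3 == ceil(10 / 4)

Caching the value in __init__ means a later deepcopy of the dataset only carries the already-computed number instead of re-querying len(source), which is the situation described by the comment added in the first hunk.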