minddata fix generatordataset get_dataset_size issue

pull/11058/head
xiefangqi 4 years ago
parent 274e0aa750
commit bd0cf9ed43

@ -3780,6 +3780,18 @@ class GeneratorDataset(MappableDataset):
self.schema = schema self.schema = schema
if not isinstance(schema, Schema): if not isinstance(schema, Schema):
self.schema = Schema(schema) self.schema = Schema(schema)
# Move get dataset_size by len from parse to here, because self.source will
# lose attribution of '__len__' after deepcopy.
self.dataset_size = None
if hasattr(self.source, "__len__"):
if not self.num_shards:
self.dataset_size = len(self.source)
else:
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
if id(self) in memodict: if id(self) in memodict:
@ -3838,24 +3850,16 @@ class GeneratorDataset(MappableDataset):
return self.sampler.is_sharded() return self.sampler.is_sharded()
def parse(self, children=None): def parse(self, children=None):
dataset_size = -1 if self.dataset_size is None:
if hasattr(self.source, "__len__"): self.dataset_size = -1
if not self.num_shards:
dataset_size = len(self.source)
else:
dataset_size = math.ceil(len(self.source) / self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < dataset_size:
dataset_size = rows_from_sampler
if self.schema is None: if self.schema is None:
return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize( return cde.GeneratorNode(self.source, self.column_names, self.column_types).SetGeneratorDatasetSize(
dataset_size) \ self.dataset_size) \
.SetNumWorkers(self.num_parallel_workers) .SetNumWorkers(self.num_parallel_workers)
schema = self.schema schema = self.schema
if isinstance(schema, Schema): if isinstance(schema, Schema):
schema = self.schema.cpp_schema schema = self.schema.cpp_schema
return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(dataset_size).SetNumWorkers( return cde.GeneratorNode(self.source, schema).SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers(
self.num_parallel_workers) self.num_parallel_workers)

Loading…
Cancel
Save