|
|
|
@ -2504,11 +2504,12 @@ class GeneratorDataset(SourceDataset):
|
|
|
|
|
Iterable source is required to return a tuple of numpy arrays as a row of the dataset on next(iter(source)).
|
|
|
|
|
Random accessible source is required to return a tuple of numpy arrays as a row of the dataset on
|
|
|
|
|
source[idx].
|
|
|
|
|
column_names (list[str]): List of column names of the dataset.
|
|
|
|
|
column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
|
|
|
|
|
provide either column_names or schema.
|
|
|
|
|
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
|
|
|
|
|
If provided, a sanity check will be performed on the generator output.
|
|
|
|
|
schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
|
|
|
|
|
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
|
|
|
|
|
schema (Schema/String, optional): Path to the json schema file or schema object (default=None). Users are
|
|
|
|
|
required to provide either column_names or schema. If both are provided, schema will be used.
|
|
|
|
|
num_samples (int, optional): The number of samples to be included in the dataset
|
|
|
|
|
(default=None, all samples).
|
|
|
|
|
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
|
|
|
|
@ -2555,8 +2556,8 @@ class GeneratorDataset(SourceDataset):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@check_generatordataset
|
|
|
|
|
def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
|
|
|
|
|
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
|
|
|
|
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
|
|
|
|
|
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
|
|
|
|
super().__init__(num_parallel_workers)
|
|
|
|
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
|
|
|
|
if self.sampler is not None and hasattr(source, "__getitem__"):
|
|
|
|
@ -2598,6 +2599,16 @@ class GeneratorDataset(SourceDataset):
|
|
|
|
|
else:
|
|
|
|
|
self.column_types = column_types
|
|
|
|
|
|
|
|
|
|
if schema is not None:
|
|
|
|
|
self.schema = schema
|
|
|
|
|
if not isinstance(schema, Schema):
|
|
|
|
|
self.schema = Schema(schema)
|
|
|
|
|
self.column_names = []
|
|
|
|
|
self.column_types = []
|
|
|
|
|
for col in self.schema.columns:
|
|
|
|
|
self.column_names.append(col["name"])
|
|
|
|
|
self.column_types.append(DataType(col["type"]))
|
|
|
|
|
|
|
|
|
|
def get_args(self):
|
|
|
|
|
args = super().get_args()
|
|
|
|
|
args["source"] = self.source
|
|
|
|
|