@ -1748,14 +1748,70 @@ class MindDataset(SourceDataset):
return num_rows
def ds_fn(dataset):
    """
    Generator function wrapper that yields every row of an iterable dataset.

    Args:
        dataset (Iterable): Object whose iteration produces rows; each row must
            itself be iterable.

    Yields:
        tuple, one numpy ndarray per element of the row.
    """
    for val in dataset:
        # convert output tensors to ndarrays
        yield tuple(np.array(x) for x in val)
def _iter_fn(dataset, num_samples):
    """
    Generator function wrapper for iterable dataset

    Args:
        dataset (Iterable): Object whose iteration produces rows; each row must
            itself be iterable.
        num_samples (int or None): Upper bound on the number of rows to yield.
            None means every row of the dataset is yielded.

    Yields:
        tuple, one numpy ndarray per element of the row.
    """
    if num_samples is not None:
        ds_iter = iter(dataset)
        for _ in range(num_samples):
            try:
                val = next(ds_iter)
            except StopIteration:
                # The dataset ran out before num_samples rows were produced.
                # End the generator here: a StopIteration that escapes a
                # generator body is turned into RuntimeError (PEP 479).
                return
            # convert output tensors to ndarrays
            yield tuple(np.array(x) for x in val)
    else:
        for val in dataset:
            # convert output tensors to ndarrays
            yield tuple(np.array(x) for x in val)
def _generator_fn(generator, num_samples):
    """
    Generator function wrapper for generator function dataset

    Args:
        generator (callable): Zero-argument callable; its return value is
            advanced with next(), so it must be an iterator.
        num_samples (int or None): Upper bound on the number of rows to yield.
            None means every row produced by the generator is yielded.

    Yields:
        The rows produced by generator(), unchanged.
    """
    if num_samples is not None:
        gen_iter = generator()
        for _ in range(num_samples):
            try:
                val = next(gen_iter)
            except StopIteration:
                # The generator ran out before num_samples rows were produced.
                # End here: a StopIteration that escapes a generator body is
                # turned into RuntimeError (PEP 479).
                return
            yield val
    else:
        yield from generator()
def sampler_fn(sampler, dataset):
for i in sampler:
def _py_sampler_fn(sampler, num_samples, dataset):
    """
    Generator function wrapper for mappable dataset with python sampler

    Args:
        sampler (Iterable): Iterable of indices; each index is used as dataset[idx].
        num_samples (int or None): Upper bound on the number of rows to yield.
            None means one row is yielded for every index the sampler produces.
        dataset: Object supporting dataset[idx]; each row must itself be iterable.

    Yields:
        tuple, one numpy ndarray per element of the selected row.
    """
    if num_samples is not None:
        sampler_iter = iter(sampler)
        for _ in range(num_samples):
            try:
                idx = next(sampler_iter)
            except StopIteration:
                # The sampler ran out before num_samples indices were produced.
                # End here: a StopIteration that escapes a generator body is
                # turned into RuntimeError (PEP 479).
                return
            val = dataset[idx]
            # convert output tensors to ndarrays
            yield tuple(np.array(x) for x in val)
    else:
        for i in sampler:
            val = dataset[i]
            # convert output tensors to ndarrays
            yield tuple(np.array(x) for x in val)
def _cpp_sampler_fn(sampler, dataset):
    """
    Generator function wrapper for mappable dataset with cpp sampler

    Args:
        sampler: Object exposing get_indices(), whose result is iterated to
            obtain the indices to read.
        dataset: Object supporting dataset[idx]; each row must itself be iterable.

    Yields:
        tuple, one numpy ndarray per element of the selected row.
    """
    indices = sampler.get_indices()
    for i in indices:
        val = dataset[i]
        # convert output tensors to ndarrays
        yield tuple(np.array(x) for x in val)
@ -1763,49 +1819,122 @@ def sampler_fn(sampler, dataset):
class GeneratorDataset(SourceDataset):
A source dataset that generates data from calling generator function each epoch.
A source dataset that generates data from python by invoking python data source each epoch.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
generator_function (callable):
A callable object that returns a Generator object that supports the iter() protocol.
Generator object is required to return a tuple of numpy array as a row of the dataset on next().
source (Callable/Iterable/Random Accessible):
A generator callable object, an iterable python object or a random accessible python object.
Callable source is required to return a tuple of numpy array as a row of the dataset on source().next().
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
Random accessible source is required to return a tuple of numpy array as a row of the dataset on source[idx].
column_names (list[str]): List of column names of the dataset.
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
If provided, sanity check will be performed on generator output.
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all samples).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.
>>> import mindspore.dataset as ds
>>> # 1) generator function that generates multi-dimensional data
>>> import mindspore.dataengine as de
>>> # 1) Multidimensional generator function as callable input
>>> def generator_md():
>>> for i in range(64):
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>> # create multi_dimension_generator_dataset with GeneratorMD() and column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2) generator function that generates multi-columns data
>>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2) Multi-column generator function as callable input
>>> def generator_mc(maxid = 64):
>>> for i in range(maxid):
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
>>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2"
>>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1", "col2"])
>>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
>>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"])
>>> # 3) Iterable dataset as iterable input
>>> class MyIterable():
>>> def __iter__(self):
>>> return # User implementation
>>> # create iterable_generator_dataset with MyIterable object
>>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
>>> # 4) Random accessible dataset as Random accessible input
>>> class MyRA():
>>> def __getitem__(self, index):
>>> return # User implementation
>>> # create ra_generator_dataset with MyRA object
>>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
>>> # List/Dict/Tuple is also random accessible
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
>>> # 5) Built-in Sampler
>>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
# NOTE(review): two __init__ signatures appear back to back below (one taking
# generator_function, one taking source), and statements assigning
# self.generator_function alternate with statements assigning self.source.
# This span reads like two revisions interleaved with some lines missing;
# confirm which revision is live before relying on any single statement.
def __init__(self, generator_function, column_names, column_types=None, prefetch_size=None, sampler=None):
if sampler is not None:
# The lambda defers the sampler_fn call until self.generator_function() is invoked.
self.generator_function = (lambda: sampler_fn(sampler, generator_function))
def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
# Reduce the sampling-related arguments to one value via _select_sampler;
# the check below treats None as "no sampler selected".
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
# Index-based sampling is only attempted when a sampler was selected and
# source exposes __getitem__.
if self.sampler is not None and hasattr(source, "__getitem__"):
# NOTE(review): the isinstance tuple below is not closed within this span.
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
# Without an explicit num_samples, fall back to len(source).
if num_samples is None:
num_samples = len(source)
sampler_instance = self.sampler.create()
# The lambdas defer building the row generator until self.source() is invoked.
# NOTE(review): presumably the _cpp_sampler_fn and _py_sampler_fn assignments
# are the two arms of the isinstance check above — no else is visible; confirm.
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
# test to see if generator_function is iterable
# NOTE(review): this except has no visible try; presumably it guards an
# iterability probe of the input — confirm.
except TypeError:
# generator_function was not iterable, assume it is a function
self.generator_function = generator_function
# Use generator function if input callable
self.source = (lambda: _generator_fn(source, num_samples))
# generator_function was iterable, build a function around it
self.generator_function = (lambda: ds_fn(generator_function))
# Use iterator function if input is iterable
# Random accessible input is also iterable
self.source = (lambda: _iter_fn(source, num_samples))
self.column_names = column_names
@ -1813,17 +1942,12 @@ class GeneratorDataset(SourceDataset):
# NOTE(review): self.column_types is assigned twice in a row; as written the
# second assignment wins. Presumably these are the two arms of a conditional
# whose if/else lines are not visible — confirm.
self.column_types = mstypelist_to_detypelist(column_types)
self.column_types = column_types
self.distribution = ""
self.prefetch_size = prefetch_size
# NOTE(review): this overwrites any self.sampler set earlier from
# _select_sampler with the raw sampler argument — confirm this line is current.
self.sampler = sampler
# Return the argument dict from the parent class's get_args(), extended with
# this dataset's own attributes under the keys assigned below.
# NOTE(review): both "generator_function" and "source" are read here, and
# "prefetch_size" is read although the source-based __init__ signature has no
# such parameter; this looks like two revisions interleaved — confirm which
# keys are current.
def get_args(self):
args = super().get_args()
args["generator_function"] = self.generator_function
args["source"] = self.source
args["column_names"] = self.column_names
args["column_types"] = self.column_types
args["prefetch_size"] = self.prefetch_size
args["sampler"] = self.sampler
return args
def get_dataset_size(self):