|
|
|
@ -25,7 +25,6 @@ import mindspore._c_dataengine as cde
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
from ..core import validator_helpers as validator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|
|
|
|
"""
|
|
|
|
|
Create sampler based on user input.
|
|
|
|
@ -57,8 +56,14 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|
|
|
|
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
|
|
|
|
if isinstance(input_sampler, BuiltinSampler):
|
|
|
|
|
return input_sampler
|
|
|
|
|
return SubsetSampler(input_sampler, num_samples)
|
|
|
|
|
|
|
|
|
|
if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)):
|
|
|
|
|
return SubsetSampler(input_sampler, num_samples)
|
|
|
|
|
if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
|
|
|
|
|
# in this case, the user passed in their own sampler object that's not of type BuiltinSampler
|
|
|
|
|
return IterSampler(input_sampler, num_samples)
|
|
|
|
|
if isinstance(input_sampler, int):
|
|
|
|
|
return SubsetSampler([input_sampler])
|
|
|
|
|
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
|
|
|
|
|
if shuffle is None:
|
|
|
|
|
if num_shards is not None:
|
|
|
|
|
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
|
|
|
@ -619,13 +624,6 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, indices, num_samples=None):
|
|
|
|
|
def _is_iterable(obj):
|
|
|
|
|
try:
|
|
|
|
|
iter(obj)
|
|
|
|
|
except TypeError:
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def _get_sample_ids_as_list(sampler, number_of_samples=None):
|
|
|
|
|
if number_of_samples is None:
|
|
|
|
|
return list(sampler)
|
|
|
|
@ -635,7 +633,7 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
|
|
|
|
|
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
|
|
|
|
|
|
|
|
|
if not isinstance(indices, str) and _is_iterable(indices):
|
|
|
|
|
if not isinstance(indices, str) and validator.is_iterable(indices):
|
|
|
|
|
indices = _get_sample_ids_as_list(indices, num_samples)
|
|
|
|
|
elif isinstance(indices, int):
|
|
|
|
|
indices = [indices]
|
|
|
|
@ -725,6 +723,42 @@ class SubsetRandomSampler(SubsetSampler):
|
|
|
|
|
return c_sampler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IterSampler(Sampler):
    """
    Wrap a user-provided iterable object that does not inherit from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operators and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

        When num_samples is not given, the iterable is traversed once at construction time to count
        its elements. If the iterable turns out to be a one-shot iterator (for example a generator
        object), the ids read during that traversal are kept and served by this sampler, because the
        original object can no longer produce them.

    Args:
        sampler (iterable object): a user-defined iterable object.
        num_samples (int, optional): Number of elements to sample (default=None, all elements).

    Examples:
        >>> class MySampler():
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i

        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    """

    def __init__(self, sampler, num_samples=None):
        if num_samples is None:
            # Counting requires a full pass over the user's iterable.
            sample_ids = list(sampler)
            num_samples = len(sample_ids)
            # Objects implementing __next__ are iterators and may be single-use. Probe whether a
            # fresh iter() can still produce a value; if not, the counting pass drained it, so
            # fall back to the ids we already collected instead of storing an empty iterator.
            # Re-iterable objects (no __next__, or a self-resetting __iter__) are left untouched
            # so that every __iter__ call still asks the user's object for a new sequence.
            if sample_ids and hasattr(type(sampler), "__next__"):
                exhausted = object()
                if next(iter(sampler), exhausted) is exhausted:
                    sampler = sample_ids
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        # Delegate to the wrapped object (or to the cached ids of a drained one-shot iterator).
        return iter(self.sampler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WeightedRandomSampler(BuiltinSampler):
|
|
|
|
|
"""
|
|
|
|
|
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
|
|
|
|