diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc index 82299edba2..512f059311 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc @@ -167,7 +167,6 @@ std::shared_ptr toSamplerObj(py::handle py_sampler, bool isMindDatas if (py_sampler) { std::shared_ptr sampler_obj; if (!isMindDataset) { - // Common Sampler auto parse = py::reinterpret_borrow(py_sampler).attr("parse"); sampler_obj = parse().cast>(); } else { diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index 9d5478be8d..9d9f98e772 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -44,6 +44,21 @@ valid_detype = [ "uint32", "uint64", "float16", "float32", "float64", "string" ] +def is_iterable(obj): + """ + Helper function to check if object is iterable. + + Args: + obj (any): object to check if iterable + + Returns: + bool, true if object iteratable + """ + try: + iter(obj) + except TypeError: + return False + return True def pad_arg_name(arg_name): if arg_name != "": diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 748e20dda9..b61f5cffbe 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -25,7 +25,6 @@ import mindspore._c_dataengine as cde import mindspore.dataset as ds from ..core import validator_helpers as validator - def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): """ Create sampler based on user input. @@ -57,8 +56,14 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) if isinstance(input_sampler, BuiltinSampler): return input_sampler - return SubsetSampler(input_sampler, num_samples) - + if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)): + return SubsetSampler(input_sampler, num_samples) + if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler): + # in this case, the user passed in their own sampler object that's not of type BuiltinSampler + return IterSampler(input_sampler, num_samples) + if isinstance(input_sampler, int): + return SubsetSampler([input_sampler]) + raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) if shuffle is None: if num_shards is not None: # If shuffle is not specified, sharding enabled, use distributed random sampler @@ -619,13 +624,6 @@ class SubsetSampler(BuiltinSampler): """ def __init__(self, indices, num_samples=None): - def _is_iterable(obj): - try: - iter(obj) - except TypeError: - return False - return True - def _get_sample_ids_as_list(sampler, number_of_samples=None): if number_of_samples is None: return list(sampler) @@ -635,7 +633,7 @@ class SubsetSampler(BuiltinSampler): return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] - if not isinstance(indices, str) and _is_iterable(indices): + if not isinstance(indices, str) and validator.is_iterable(indices): indices = _get_sample_ids_as_list(indices, num_samples) elif isinstance(indices, int): indices = [indices] @@ -725,6 +723,42 @@ class SubsetRandomSampler(SubsetSampler): return c_sampler +class IterSampler(Sampler): + """ + User provided an iterable object without inheriting from our Sampler class. + + Note: + This class exists to allow handshake logic between dataset operators and user defined samplers. + By constructing this object we avoid the user having to inherit from our Sampler class. + + Args: + sampler (iterable object): an user defined iterable object. + num_samples (int, optional): Number of elements to sample (default=None, all elements). + + Examples: + >>> class MySampler(): + >>> def __iter__(self): + >>> for i in range(99, -1, -1): + >>> yield i + + >>> # creates an IterSampler + >>> sampler = ds.IterSampler(sampler=MySampler()) + >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, + ... num_parallel_workers=8, + ... sampler=sampler) + + """ + + def __init__(self, sampler, num_samples=None): + if num_samples is None: + num_samples = len(list(sampler)) + super().__init__(num_samples=num_samples) + self.sampler = sampler + + def __iter__(self): + return iter(self.sampler) + + class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 031f5d93f3..f625ab62d8 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -141,9 +141,23 @@ def test_python_sampler(): assert data[0].asnumpy() == (np.array(i),) i = i - 1 + # This 2nd case is the one that exhibits the same behavior as the case above without inheritance + def test_generator_iter_sampler(): + class MySampler(): + def __iter__(self): + for i in range(99, -1, -1): + yield i + + data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler()) + i = 99 + for data in data1: + assert data[0].asnumpy() == (np.array(i),) + i = i - 1 + assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] test_generator() + test_generator_iter_sampler() def test_sequential_sampler2():