# ---------------------------------------------------------------------------
# Patch context (reconstructed from a whitespace-collapsed unified diff):
#   diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
#   index 3225ebc806..62c8e75ca9 100644
#   --- a/mindspore/dataset/engine/datasets.py
#   +++ b/mindspore/dataset/engine/datasets.py
#
# Hunk "@@ -25,6 +25,7 @@" adds `import queue` to the module imports below.
# ---------------------------------------------------------------------------
import os
import random
import uuid
import multiprocessing
import queue
from enum import Enum
from importlib import import_module

# ---------------------------------------------------------------------------
# Hunk "@@ -2124,6 +2125,142 @@ def _cpp_sampler_fn(sampler, dataset):" inserts the
# definitions below directly after the existing `_cpp_sampler_fn` generator,
# whose last context line is:
#         yield tuple([np.array(x) for x in val])
# `np` is the name the surrounding module already uses in that context line.
# ---------------------------------------------------------------------------


def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
    """
    Multiprocessing generator function wrapper for mappable dataset with cpp sampler

    Args:
        sampler: Sampler object exposing `get_indices()`.
        dataset: Random accessible object; workers read it as `dataset[idx]`.
        num_worker (int): Number of worker processes to spawn.

    Returns:
        Generator, yields one tuple of numpy arrays per sampled index.
    """
    indices = sampler.get_indices()
    return _sampler_fn_mp(indices, dataset, num_worker)


def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
    """
    Multiprocessing generator function wrapper for mappable dataset with python sampler

    Args:
        sampler (Iterable): Iterable producing the indices to fetch.
        num_samples (int or None): Maximum number of indices to take, None for all.
        dataset: Random accessible object; workers read it as `dataset[idx]`.
        num_worker (int): Number of worker processes to spawn.

    Returns:
        Generator, yields one tuple of numpy arrays per sampled index.
    """
    indices = _fetch_py_sampler_indices(sampler, num_samples)
    return _sampler_fn_mp(indices, dataset, num_worker)


def _fetch_py_sampler_indices(sampler, num_samples):
    """
    Indices fetcher for python sampler

    Args:
        sampler (Iterable): Iterable producing indices.
        num_samples (int or None): Maximum number of indices to take, None for all.

    Returns:
        list, at most `num_samples` indices in sampler order (fewer if the sampler is shorter).
    """
    if num_samples is None:
        return list(sampler)
    # zip() advances range() first, so once num_samples items were taken the
    # sampler is not advanced any further.
    return [index for _, index in zip(range(num_samples), sampler)]


def _fill_worker_indices(workers, indices, idx):
    """
    Worker index queue filler, fill worker index queue in round robin order

    Filling stops at the first worker whose queue is full instead of skipping it,
    so that index k always lands on worker k % len(workers). The master relies on
    that mapping when it fetches results in the same round robin order.

    Args:
        workers (list): Objects with a non-blocking `put(item)` raising queue.Full.
        indices (list): All indices of the epoch.
        idx (int): Position of the first index that has not been dispatched yet.

    Returns:
        int, position of the next index to dispatch (len(indices) when all are sent).
    """
    num_worker = len(workers)
    if num_worker == 0:
        # Nothing can be dispatched; avoid a modulo-by-zero below.
        return idx
    num_indices = len(indices)
    while idx < num_indices:
        try:
            workers[idx % num_worker].put(indices[idx])
        except queue.Full:
            break
        idx += 1
    return idx


def _stop_workers(workers):
    """
    Terminate every worker that is still running and reap every worker that was started
    """
    for worker in workers:
        if worker.is_alive():
            worker.terminate()
        # pid is None for a process that was never started; join() would fail on it.
        if worker.pid is not None:
            worker.join()


def _sampler_fn_mp(indices, dataset, num_worker):
    """
    Multiprocessing generator function wrapper master process

    Spawns `num_worker` worker processes, dispatches `indices` to them in round
    robin order and yields the results in the original index order. All workers
    are terminated and joined when the generator finishes, fails, or is closed
    early by the consumer, so no process outlives the epoch.

    Args:
        indices (list): Indices to fetch from `dataset`.
        dataset: Random accessible object; workers read it as `dataset[idx]`.
        num_worker (int): Number of worker processes, must be at least 1.

    Yields:
        tuple, one numpy array per column of the fetched row.

    Raises:
        ValueError: If `num_worker` is smaller than 1 (raised on first iteration).
        RuntimeError: If a worker does not deliver a result in time, or the master
            is interrupted with KeyboardInterrupt while waiting for a result.
    """
    if num_worker < 1:
        raise ValueError("num_worker must be at least 1, got {}".format(num_worker))

    # Event for end of epoch
    eoe = multiprocessing.Event()

    # Create workers
    workers = []
    for _ in range(num_worker):
        worker = _GeneratorWorker(dataset, eoe)
        worker.daemon = True
        workers.append(worker)

    num_indices = len(indices)
    try:
        # Fill initial index queues
        idx_cursor = _fill_worker_indices(workers, indices, 0)

        # Start all workers
        for worker in workers:
            worker.start()

        # Fetch results in the same round robin order the indices were dispatched in
        for i in range(num_indices):
            try:
                result = workers[i % num_worker].get()
            except queue.Empty as err:
                raise RuntimeError("Generator worker process timeout") from err
            except KeyboardInterrupt as err:
                raise RuntimeError("Generator worker receives KeyboardInterrupt") from err
            # A slot was freed in one index queue, try to dispatch more work
            if idx_cursor < num_indices:
                idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
            # Set eoe event once all indices are sent
            if idx_cursor == num_indices and not eoe.is_set():
                eoe.set()
            yield tuple([np.array(x) for x in result])
    finally:
        # Runs on normal exhaustion, on error and on generator close alike.
        _stop_workers(workers)


def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
    """
    Multiprocessing generator worker process loop

    Repeatedly takes an index from `idx_queue`, reads `dataset[idx]` and puts the
    row on `result_queue`. A None index ends the loop.

    Args:
        dataset: Random accessible object indexed with the received indices.
        idx_queue: Queue of indices to fetch, read with a blocking `get()`.
        result_queue: Queue receiving the fetched rows, written with a blocking `put()`.
        eoe: Event that must be set by the master before a None index is received.

    Raises:
        RuntimeError: If None is received while `eoe` is not set.
        Exception: If KeyboardInterrupt is received while blocked on a queue.
    """
    while True:
        # Fetch index, block
        try:
            idx = idx_queue.get()
        except KeyboardInterrupt:
            raise Exception("Generator worker receives KeyboardInterrupt")
        if idx is None:
            # None marks the end of the index stream; it is only legal after the
            # master signalled end of epoch. An explicit check is used instead of
            # `assert` so it is not stripped when running with -O.
            if not eoe.is_set():
                raise RuntimeError("Generator worker received end marker before end of epoch was set")
            return
        # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
        result = dataset[idx]
        # Send data, block
        try:
            result_queue.put(result)
        except KeyboardInterrupt:
            raise Exception("Generator worker receives KeyboardInterrupt")
        # Drop the references before blocking on the next index
        del result, idx


class _GeneratorWorker(multiprocessing.Process):
    """
    Worker process for multiprocess Generator

    Owns one bounded index queue (master -> worker) and one bounded result
    queue (worker -> master) and runs `_generator_worker_loop` on them.

    Args:
        dataset: Random accessible object handed to the worker loop.
        eoe: End of epoch event handed to the worker loop.
        queue_size (int, optional): Capacity of each of the two queues (default=16).
    """
    def __init__(self, dataset, eoe, queue_size=16):
        idx_queue = multiprocessing.Queue(queue_size)
        res_queue = multiprocessing.Queue(queue_size)
        super().__init__(target=_generator_worker_loop, args=(dataset, idx_queue, res_queue, eoe))
        self.idx_queue = idx_queue
        self.res_queue = res_queue

    def put(self, item):
        """
        Put function for worker index queue. Never block. Raise queue.Full on failure.
        """
        self.idx_queue.put_nowait(item)

    def get(self, timeout=5):
        """
        Get function for worker result queue. Block with timeout.

        Args:
            timeout (float, optional): Seconds to wait for a result (default=5).

        Raises:
            queue.Empty: If no result arrives within `timeout` seconds.
        """
        return self.res_queue.get(timeout=timeout)


# ---------------------------------------------------------------------------
# Remaining hunks of datasets.py. They modify `GeneratorDataset`, which is only
# partially visible here, so they are kept unchanged as patch text. Line breaks
# and indentation were reconstructed from the collapsed diff.
#
#    class GeneratorDataset(SourceDataset):
#        """
#        A source dataset that generate data from python by invoking python data source each epoch.
#   @@ -2171,6 +2308,7 @@ class GeneratorDataset(SourceDataset):
#                If the schema is not provided, the meta data from column_names and column_types is considered the schema.
#            num_samples (int, optional): The number of samples to be included in the dataset
#                (default=None, all images).
#   +        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
#            shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
#                (default=None, expected order behavior shown in the table).
#            sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
#   @@ -2229,9 +2367,15 @@ class GeneratorDataset(SourceDataset):
#                    sampler_instance.set_num_rows(len(source))
#                    sampler_instance.set_num_samples(num_samples)
#                    sampler_instance.initialize()
#   -                self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
#   +                if num_parallel_workers > 1:
#   +                    self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
#   +                else:
#   +                    self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
#                else:
#   -                self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
#   +                if num_parallel_workers > 1:
#   +                    self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers))
#   +                else:
#   +                    self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
#            else:
#                try:
#                    iter(source)
#
#   diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py
#   index c224c5a2ea..4daf952eba 100644
#   --- a/tests/ut/python/dataset/test_generator.py
#   +++ b/tests/ut/python/dataset/test_generator.py
#   @@ -391,6
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Patch context: tests/ut/python/dataset/test_generator.py
# (continuation of hunk header "@@ -391,6 +391,80 @@ def test_case_13():")
# The four tests below are inserted after the end of test_case_13, whose last
# context line is `i = i + 1`.
# `np`, `ds` and `logger` are the names the tests use from the surrounding
# test module; their imports are outside this view.
# ---------------------------------------------------------------------------


def test_case_14():
    """
    Test 1D Generator MP + CPP sampler
    """
    logger.info("Test 1D Generator MP : 0 - 255")

    source = [(np.array([x]),) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2)
    num_rows = 0
    for data in ds1.create_dict_iterator():  # each data is a dictionary
        # The sequence 0..255 restarts for the second repeat
        golden = np.array([num_rows % 256])
        assert np.array_equal(data["data"], golden)
        num_rows += 1
    # Guard against an iterator that silently yields nothing (or too little)
    assert num_rows == 512


def test_case_15():
    """
    Test 1D Generator MP + Python sampler
    """
    logger.info("Test 1D Generator MP : 0 - 255")

    sampler = list(range(256))
    source = [(np.array([x]),) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2)
    num_rows = 0
    for data in ds1.create_dict_iterator():  # each data is a dictionary
        # The sequence 0..255 restarts for the second repeat
        golden = np.array([num_rows % 256])
        assert np.array_equal(data["data"], golden)
        num_rows += 1
    # Guard against an iterator that silently yields nothing (or too little)
    assert num_rows == 512


def test_case_16():
    """
    Test multi column generator Mp + CPP sampler
    """
    logger.info("Test multi column generator")

    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
    # apply dataset operations; num_parallel_workers > 1 selects the multiprocessing path
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler(), num_parallel_workers=4)

    num_rows = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([num_rows])
        assert np.array_equal(item["col0"], golden)
        golden = np.array([num_rows + 1])
        assert np.array_equal(item["col1"], golden)
        num_rows += 1
    assert num_rows == 256


def test_case_17():
    """
    Test multi column generator Mp + Python sampler
    """
    logger.info("Test multi column generator")

    sampler = list(range(256))
    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
    # apply dataset operations; num_parallel_workers > 1 selects the multiprocessing path
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler, num_parallel_workers=4)

    num_rows = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([num_rows])
        assert np.array_equal(item["col0"], golden)
        golden = np.array([num_rows + 1])
        assert np.array_equal(item["col1"], golden)
        num_rows += 1
    assert num_rows == 256


# ---------------------------------------------------------------------------
# Patch context (unchanged lines following the inserted tests):
#    def test_case_error_1():
#        def generator_np():
#            for i in range(64):
#
# Hunk "@@ -506,6 +580,25 @@ def test_num_samples_underflow():" inserts the manual
# test below after these context lines:
#            count = count + 1
#        assert count == 64
# ---------------------------------------------------------------------------

def manual_test_keyborad_interrupt():
    """
    Manual test for keyboard interrupt handling with multiprocessing workers

    Every __getitem__ call spins forever, so the iteration below never receives
    a row. It is meant to be run by hand and interrupted; it is not part of the
    automatic test list.
    """
    logger.info("Test Generator MP keyboard interrupt")

    class MyDS:
        def __getitem__(self, item):
            # Never returns: keeps the worker process busy
            while True:
                pass

        def __len__(self):
            return 1024

    ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
    for _ in ds1.create_dict_iterator():  # each data is a dictionary
        pass


# ---------------------------------------------------------------------------
# Remaining hunks of test_generator.py. They extend the `__main__` block, which
# is only partially visible here, so they are kept unchanged as patch text.
# Line breaks and indentation were reconstructed from the collapsed diff.
#
#    if __name__ == "__main__":
#        test_case_0()
#   @@ -522,6 +615,10 @@ if __name__ == "__main__":
#        test_case_11()
#        test_case_12()
#        test_case_13()
#   +    test_case_14()
#   +    test_case_15()
#   +    test_case_16()
#   +    test_case_17()
#        test_case_error_1()
#        test_case_error_2()
#        test_case_error_3()
#   @@ -529,3 +626,5 @@ if __name__ == "__main__":
#        test_sequential_sampler()
#        test_distributed_sampler()
#        test_random_sampler()
#   +
#   +
# ---------------------------------------------------------------------------