From cb9c6fad863b2d8bacd0b686af496c3c5f305eec Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Tue, 28 Jul 2020 19:42:48 +0800 Subject: [PATCH] fix numpyslice issue --- mindspore/dataset/engine/datasets.py | 60 ++++++++++++++-------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index afacb68089..caf3857e20 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3219,33 +3219,9 @@ class GeneratorDataset(MappableDataset): def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) + self.source = source self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - if self.sampler is not None and hasattr(source, "__getitem__"): - if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, - samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler, samplers.Sampler)): - sampler_instance = self.sampler.create() - sampler_instance.set_num_rows(len(source)) - sampler_instance.initialize() - if num_parallel_workers > 1: - self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) - else: - self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) - else: - if num_parallel_workers > 1: - self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers)) - else: - self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) - else: - try: - iter(source) - except TypeError: - # Use generator function if input callable - self.source = (lambda: _generator_fn(source, num_samples)) - else: - # Use iterator function if input is iterable - # Random accessible input is also iterable - self.source = (lambda: _iter_fn(source, num_samples)) + self.num_samples = num_samples if column_names is not None and not isinstance(column_names, list): column_names = [column_names] @@ -3310,9 +3286,35 @@ class GeneratorDataset(MappableDataset): new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.column_types = copy.deepcopy(self.column_types, memodict) new_op.column_names = copy.deepcopy(self.column_names, memodict) - - new_op.source = self.source - new_op.sampler = self.sampler + new_op.num_samples = copy.deepcopy(self.num_samples, memodict) + + new_op.sampler = copy.deepcopy(self.sampler) + if new_op.sampler is not None and hasattr(self.source, "__getitem__"): + if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, + samplers.RandomSampler, samplers.SubsetRandomSampler, + samplers.WeightedRandomSampler, samplers.Sampler)): + sampler_instance = new_op.sampler.create() + sampler_instance.set_num_rows(len(self.source)) + sampler_instance.initialize() + if new_op.num_parallel_workers > 1: + new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) + else: + new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) + else: + if new_op.num_parallel_workers > 1: + new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) + else: + new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) + else: + try: + iter(self.source) + except TypeError: + # Use generator function if input callable + new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples)) + else: + # Use iterator function if input is iterable + # Random accessible input is also iterable + new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples)) return new_op