diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 54a2b15e9f..235620604e 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1255,7 +1255,7 @@ class Dataset: del api_tree @check_tuple_iterator - def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False): + def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): """ Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data. @@ -1269,6 +1269,8 @@ class Dataset: (default=-1, iterator can be iterated infinite number of epochs) output_numpy (bool, optional): Whether or not to output NumPy datatype. If output_numpy=False, iterator will output MSTensor (default=False). + do_copy (bool, optional): When the output data type is mindspore.Tensor, + this parameter selects the conversion method; pass False only for better performance (default=True). Returns: Iterator, list of ndarrays. 
@@ -1290,7 +1292,7 @@ class Dataset: if Dataset._noop_mode(): return DummyIterator(self, 'tuple') - return TupleIterator(self, columns, num_epochs, output_numpy) + return TupleIterator(self, columns, num_epochs, output_numpy, do_copy) @check_dict_iterator def create_dict_iterator(self, num_epochs=-1, output_numpy=False): @@ -2788,7 +2790,7 @@ class TransferDataset(Dataset): def create_dict_iterator(self, num_epochs=-1, output_numpy=False): raise RuntimeError("TransferDataset is not iterable.") - def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False): + def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): raise RuntimeError("TransferDataset is not iterable.") def __iter__(self): diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index c71f3c3b4e..945e9ae696 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -63,7 +63,7 @@ class Iterator: dataset: Dataset to be iterated over """ - def __init__(self, dataset, num_epochs=-1, output_numpy=False): + def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True): self._col_names = None # create a copy of tree and work on it. @@ -80,7 +80,10 @@ class Iterator: self._transform_tensor = lambda t: t.as_array() if not output_numpy: - self._transform_tensor = lambda t: Tensor(t.as_array()) + if do_copy: + self._transform_tensor = lambda t: Tensor(t.as_array()) + else: + self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array()) self._index = 0 # todo remove next when ContextManager is done @@ -179,13 +182,13 @@ class TupleIterator(Iterator): The derived class of Iterator with list type. 
""" - def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False): + def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False, do_copy=True): if columns is not None: if not isinstance(columns, list): columns = [columns] # todo: move next to IR dataset = dataset.project(columns) - super().__init__(dataset, num_epochs, output_numpy) + super().__init__(dataset, num_epochs, output_numpy, do_copy) def _get_next(self): """ diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index db4801a693..939af41f75 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -298,7 +298,7 @@ def check_tuple_iterator(method): @wraps(method) def new_method(self, *args, **kwargs): - [columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs) + [columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs) nreq_param_bool = ['output_numpy'] validate_dataset_param_value(nreq_param_bool, param_dict, bool) if num_epochs is not None: diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index b561c722e5..7eeea138f4 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -394,7 +394,7 @@ class _DatasetIterNormal: self.dataset = dataset self.device_num = _get_device_num() self.global_rank = _get_global_rank() - self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num) + self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num, do_copy=False) def __iter__(self): return self diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index 3c51baaf3e..ac53f96f16 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -55,7 +55,7 @@ class MindData: self.send_epoch_end = send_epoch_end return self - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): return self.__iter__() def send(self, 
num_epochs=-1): diff --git a/tests/st/auto_parallel/optimizer_parallel.py b/tests/st/auto_parallel/optimizer_parallel.py index c62269b226..fdcdd3ef87 100644 --- a/tests/st/auto_parallel/optimizer_parallel.py +++ b/tests/st/auto_parallel/optimizer_parallel.py @@ -125,7 +125,7 @@ class FakeData: def set_label_onehot(self, is_onehot=True): self.is_onehot = is_onehot - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): _ = num_epochs return self diff --git a/tests/st/auto_parallel/parallel_strategy_search.py b/tests/st/auto_parallel/parallel_strategy_search.py index 057c45b61f..561ec9ea4a 100644 --- a/tests/st/auto_parallel/parallel_strategy_search.py +++ b/tests/st/auto_parallel/parallel_strategy_search.py @@ -128,7 +128,7 @@ class FakeData: def set_label_onehot(self, is_onehot=True): self.is_onehot = is_onehot - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): _ = num_epochs return self diff --git a/tests/st/pynative/loss_scale/test_loss_scale.py b/tests/st/pynative/loss_scale/test_loss_scale.py index ee1d845974..6fdef1af71 100644 --- a/tests/st/pynative/loss_scale/test_loss_scale.py +++ b/tests/st/pynative/loss_scale/test_loss_scale.py @@ -60,7 +60,7 @@ class MindData: def output_shapes(self): return self._output_shapes - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): return self @property diff --git a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py index 0431604411..16217be6b9 100644 --- a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py +++ b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py @@ -152,7 +152,7 @@ class DatasetLenet(): def get_repeat_count(self): return 1 - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, 
do_copy=True): return self def test_double_subgraphs_train(): diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py index da3ded3209..d12a028167 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py @@ -275,7 +275,7 @@ class DatasetLenet(): def get_repeat_count(self): return 1 - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): return self diff --git a/tests/ut/python/parallel/test_bias_add.py b/tests/ut/python/parallel/test_bias_add.py index 5c3f26bbff..8a47762519 100644 --- a/tests/ut/python/parallel/test_bias_add.py +++ b/tests/ut/python/parallel/test_bias_add.py @@ -61,7 +61,7 @@ class DatasetLenet(): def get_repeat_count(self): return 1 - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): return self diff --git a/tests/ut/python/parallel/test_gather_v2_primitive.py b/tests/ut/python/parallel/test_gather_v2_primitive.py index ac5dfc70ee..0e2c90eed4 100644 --- a/tests/ut/python/parallel/test_gather_v2_primitive.py +++ b/tests/ut/python/parallel/test_gather_v2_primitive.py @@ -59,7 +59,7 @@ class Dataset(): def get_repeat_count(self): return 1 - def create_tuple_iterator(self, num_epochs=-1): + def create_tuple_iterator(self, num_epochs=-1, do_copy=True): return self diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py index bea964392d..a77fe2f968 100644 --- a/tests/ut/python/parallel/test_pipeline_split.py +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -51,7 +51,7 @@ class DatasetLenet(): def get_batch_size(self): return 32 - def create_tuple_iterator(self, num_epochs=1): + def create_tuple_iterator(self, num_epochs=1, do_copy=True): return self