diff --git a/example/lstm_aclImdb/train.py b/example/lstm_aclImdb/train.py index 3d1c670f4e..08bea7c63d 100644 --- a/example/lstm_aclImdb/train.py +++ b/example/lstm_aclImdb/train.py @@ -71,7 +71,7 @@ if __name__ == '__main__': model = Model(network, loss, opt, {'acc': Accuracy()}) print("============== Starting Training ==============") - ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs) + ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index 96c83c3114..935ba4e2bd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { } Status RandomSampler::InitSampler() { - num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive."); rnd_.seed(seed_); if (replacement_ == false) { + num_samples_ = std::min(num_samples_, num_rows_); + shuffled_ids_.reserve(num_rows_); for (int64_t i = 0; i < num_rows_; i++) { shuffled_ids_.push_back(i); } std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); } else { + num_samples_ = std::min(num_samples_, user_num_samples_); dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); } + + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive."); + samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 695f364b7f..600d8c576b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { } // Handshake and init child first. - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); - } + RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); } CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc index 320bc601b9..0ae7a7d503 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size) : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {} Status SubsetSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n"); CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n"); num_samples_ = subset_size_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h index 70ee80b0a4..5e8774f673 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 0631ade36a..ceca188112 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ - WeightedRandomSampler, Sampler + WeightedRandomSampler, SubsetSampler, Sampler from .engine.serializer_deserializer import serialize, deserialize, show from .engine.graphdata import GraphData diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ae8f8dc243..f3703b3850 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -633,9 +633,9 @@ class Dataset: Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size of the original dataset. If after rounding, any size equals 0, an error will occur. All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. - randomize (bool): determines whether or not to split the data randomly. If true, the data - will be randomly split. Otherwise, each split will be created with consecutive rows - from the dataset. + randomize (bool, optional): determines whether or not to split the data randomly (default=True). + If true, the data will be randomly split. Otherwise, each split will be created with + consecutive rows from the dataset. Note: 1. Dataset cannot be sharded if split is going to be called. 
@@ -678,7 +678,8 @@ class Dataset: ds = copy.deepcopy(self) if randomize: # want to shuffle the same way every epoch before split - ds = ds.shuffle() + # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here + ds = ds.shuffle(10000) ds.reshuffle_each_epoch = False if rows_to_skip > 0: @@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset): >>> new_sampler = ds.DistributedSampler(10, 2) >>> data.use_sampler(new_sampler) """ + if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): + raise TypeError("new_sampler is not an instance of a sampler.") + self.sampler = self.sampler.child_sampler self.add_sampler(new_sampler) @@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset): def is_sharded(self): raise NotImplementedError("MappableDataset must implement is_sharded.") + def _get_sampler_dataset_size(self): + if self.sampler is not None: + return self.sampler.get_dataset_size() + + return None @check_split def split(self, sizes, randomize=True): @@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset): Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size of the original dataset. If after rounding, any size equals 0, an error will occur. All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. - randomize (bool): determines whether or not to split the data randomly. If true, the data - will be randomly split. Otherwise, each split will be created with consecutive rows - from the dataset. + randomize (bool, optional): determines whether or not to split the data randomly (default=True). + If true, the data will be randomly split. Otherwise, each split will be created with + consecutive rows from the dataset. Note: 1. Dataset should not be sharded if split is going to be called. Instead, create a @@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp): self.iterator = TupleIterator(self) - class RangeDataset(MappableDataset): """ A source dataset that reads and parses datasets stored on disk in a range. @@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset): else: num_samples = self.num_samples num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - return get_num_rows(num_rows, self.num_shards) + if rows_from_sampler is None: + return rows_per_shard + + return min(rows_from_sampler, rows_per_shard) def num_classes(self): """ @@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset): num_samples = self.num_samples num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples) + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return rows_per_shard - return get_num_rows(num_rows, self.num_shards) + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): if self.shuffle_level is None: @@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset): Return: Number, number of batches. """ - return self._dataset_size + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return self._dataset_size + + return min(rows_from_sampler, self._dataset_size) # manually set dataset_size as a temporary solution. 
def set_dataset_size(self, value): @@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset): class_indexing = self.class_indexing num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0] + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return rows_per_shard - return get_num_rows(num_rows, self.num_shards) + return min(rows_from_sampler, rows_per_shard) def num_classes(self): """ @@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset): num_samples = self.num_samples num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True) + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - return get_num_rows(num_rows, self.num_shards) + if rows_from_sampler is None: + return rows_per_shard + + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): if self.shuffle_level is None: @@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset): num_samples = self.num_samples num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False) + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return rows_per_shard - return get_num_rows(num_rows, self.num_shards) + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): if self.shuffle_level is None: @@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset): Return: Number, number of batches. """ - return num_samples + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return self.num_samples + + return min(rows_from_sampler, self.num_samples) def is_shuffled(self): return True @@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset): Return: Number, number of batches. 
""" - return self.num_samples + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return self.num_samples + + return min(rows_from_sampler, self.num_samples) def get_class_indexing(self): """ diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 8bf223251a..265d20f389 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -114,6 +114,9 @@ class Sampler: return self.child_sampler.is_sharded() + def get_dataset_size(self): + return self._get_indices().size + class BuiltinSampler: """ @@ -146,6 +149,12 @@ class BuiltinSampler: def is_sharded(self): raise NotImplementedError("Sampler must implement is_sharded.") + def get_dataset_size(self): + if self.child_sampler is not None: + return self.child_sampler.get_dataset_size() + + return None + class DistributedSampler(BuiltinSampler): """ @@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler): return self.child_sampler.is_sharded() + def get_dataset_size(self): + return self.num_samples + class SequentialSampler(BuiltinSampler): """ @@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler): return self.child_sampler.is_sharded() + def get_dataset_size(self): + return self.subset_size + class SubsetRandomSampler(BuiltinSampler): """ @@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler): return cde.MindrecordSubsetRandomSampler(self.indices) + def get_dataset_size(self): + return len(indices) + + class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). @@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler): return False return self.child_sampler.is_sharded() + + def get_dataset_size(self): + return self.num_samples diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 01a5f8d228..04b0602152 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -164,6 +165,35 @@ def test_python_sampler(): assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] +def test_subset_sampler(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + + def test_config(num_samples, start_index, subset_size): + sampler = ds.SubsetSampler(start_index, subset_size) + d = ds.ManifestDataset(manifest_file, sampler=sampler) + + res = [] + for item in d.create_dict_iterator(): + res.append(map[(item["image"].shape[0], item["label"].item())]) + + return res + + with pytest.raises(RuntimeError) as info: + test_config(5, 0, 0) + assert "subset_size <= 0" in str(info.value) + + assert test_config(5, 0, 1) == [0] + assert test_config(5, 0, 2) == [0, 1] + assert test_config(5, 0, 3) == [0, 1, 2] + assert test_config(5, 0, 4) == [0, 1, 2, 3] + assert test_config(5, 0, 5) == [0, 1, 2, 3, 4] + assert test_config(5, 1, 1) == [1] + assert test_config(5, 2, 3) == [2, 3, 4] + assert test_config(5, 3, 2) == [3, 4] + assert test_config(5, 4, 1) == [4] + + def test_sampler_chain(): manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} @@ -190,10 +220,26 @@ def test_sampler_chain(): assert test_config(5, 3) == [3] assert test_config(5, 4) == [4] +def test_add_sampler_invalid_input(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + data1 = ds.ManifestDataset(manifest_file) + + with pytest.raises(TypeError) as info: + data1.use_sampler(1) + assert "not an instance of a sampler" in str(info.value) + + with pytest.raises(TypeError) as info: + data1.use_sampler("sampler") + assert "not an instance of a sampler" in str(info.value) + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) test_random_sampler_multi_iter(True) test_sampler_py_api() test_python_sampler() + test_subset_sampler() test_sampler_chain() + test_add_sampler_invalid_input() diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index fa28b49181..0b546e5f6f 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -23,7 +23,11 @@ from util import config_get_set_num_parallel_workers manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} -def split_with_invalid_inputs(d): +text_file_dataset_path = "../data/dataset/testTextFileDataset/*" +text_file_data = ["This is a text file.", "Another file.", "Be happy every day.", + "End of file.", "Good luck to everyone."] + +def split_with_invalid_inputs(d): with pytest.raises(ValueError) as info: s1, s2 = d.split([]) assert "sizes cannot be empty" in str(info.value) @@ -68,8 +72,8 @@ def split_with_invalid_inputs(d): s1, s2 = d.split([0.05, 0.95]) assert "percentage 0.05 is too small" in str(info.value) + def test_unmappable_invalid_input(): - text_file_dataset_path = "../data/dataset/testTextFileDataset/*" d = ds.TextFileDataset(text_file_dataset_path) split_with_invalid_inputs(d) @@ -78,11 +82,10 @@ def test_unmappable_invalid_input(): s1, s2 = d.split([4, 1]) assert "dataset 
should not be sharded before split" in str(info.value) + def test_unmappable_split(): - text_file_dataset_path = "../data/dataset/testTextFileDataset/*" - text_file_data = ["This is a text file.", "Another file.", "Be happy every day.", - "End of file.", "Good luck to everyone."] original_num_parallel_workers = config_get_set_num_parallel_workers(4) + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) s1, s2 = d.split([4, 1], randomize=False) @@ -124,6 +127,142 @@ def test_unmappable_split(): assert s1_output == text_file_data[0:2] assert s2_output == text_file_data[2:] + + # Restore configuration num_parallel_workers + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_unmappable_randomize_deterministic(): + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + + # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] + ds.config.set_seed(53) + + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([0.8, 0.2]) + + for _ in range(10): + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(item["text"].item().decode("utf8")) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + + # note no overlap + assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] + assert s2_output == [text_file_data[3]] + + # Restore configuration num_parallel_workers + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_unmappable_randomize_repeatable(): + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + + # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] + ds.config.set_seed(53) + + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([0.8, 0.2]) + + num_epochs = 5 + s1 = s1.repeat(num_epochs) + s2 = s2.repeat(num_epochs) + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(item["text"].item().decode("utf8")) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + + # note no overlap + assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs + assert s2_output == [text_file_data[3]] * num_epochs + + # Restore configuration num_parallel_workers + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_unmappable_get_dataset_size(): + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([0.8, 0.2]) + + assert d.get_dataset_size() == 5 + assert s1.get_dataset_size() == 4 + assert s2.get_dataset_size() == 1 + + +def test_unmappable_multi_split(): + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + + # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] + ds.config.set_seed(53) + + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([4, 1]) + + s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(item["text"].item().decode("utf8")) + assert s1_output == s1_correct_output + + # no randomize in second split + s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(item["text"].item().decode("utf8")) + + s1s2_output = [] + for item in 
s1s2.create_dict_iterator(): + s1s2_output.append(item["text"].item().decode("utf8")) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(item["text"].item().decode("utf8")) + + assert s1s1_output == [s1_correct_output[0]] + assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]] + assert s1s3_output == [s1_correct_output[3]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + assert s2_output == [text_file_data[3]] + + # randomize in second split + # the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0] + shuffled_ids = [2, 3, 1, 0] + + s1s1, s1s2, s1s3 = s1.split([1, 2, 1]) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(item["text"].item().decode("utf8")) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(item["text"].item().decode("utf8")) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(item["text"].item().decode("utf8")) + + assert s1s1_output == [s1_correct_output[shuffled_ids[0]]] + assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]] + assert s1s3_output == [s1_correct_output[shuffled_ids[3]]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + assert s2_output == [text_file_data[3]] + # Restore configuration num_parallel_workers ds.config.set_num_parallel_workers(original_num_parallel_workers) @@ -137,6 +276,7 @@ def test_mappable_invalid_input(): s1, s2 = d.split([4, 1]) assert "dataset should not be sharded before split" in str(info.value) + def test_mappable_split_general(): d = ds.ManifestDataset(manifest_file, shuffle=False) d = d.take(5) @@ -183,6 +323,7 @@ def test_mappable_split_general(): assert s1_output == [0, 1] assert s2_output == [2, 3, 4] + def test_mappable_split_optimized(): d = ds.ManifestDataset(manifest_file, shuffle=False) @@ -228,9 +369,9 @@ def test_mappable_split_optimized(): assert s1_output == [0, 1] assert s2_output == [2, 3, 4] + def test_mappable_randomize_deterministic(): - # set arbitrary seed for shard after split - # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] ds.config.set_seed(53) d = ds.ManifestDataset(manifest_file, shuffle=False) @@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic(): assert s1_output == [0, 1, 3, 4] assert s2_output == [2] + def test_mappable_randomize_repeatable(): - # set arbitrary seed for shard after split - # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] ds.config.set_seed(53) d = ds.ManifestDataset(manifest_file, shuffle=False) @@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable(): assert s1_output == [0, 1, 3, 4] * num_epochs assert s2_output == [2] * num_epochs + def test_mappable_sharding(): # set arbitrary seed for repeatability for shard after split - # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] ds.config.set_seed(53) num_epochs = 5 @@ -336,12 +478,94 @@ def test_mappable_sharding(): assert s2_output == [2] assert d2s2_output == [2] + +def test_mappable_get_dataset_size(): + d = ds.ManifestDataset(manifest_file, shuffle=False) + s1, s2 = d.split([4, 1]) + + assert 
d.get_dataset_size() == 5 + assert s1.get_dataset_size() == 4 + assert s2.get_dataset_size() == 1 + + +def test_mappable_multi_split(): + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] + ds.config.set_seed(53) + + d = ds.ManifestDataset(manifest_file, shuffle=False) + s1, s2 = d.split([4, 1]) + + s1_correct_output = [0, 1, 3, 4] + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + assert s1_output == s1_correct_output + + # no randomize in second split + s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1s1_output == [s1_correct_output[0]] + assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]] + assert s1s3_output == [s1_correct_output[3]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + assert s2_output == [2] + + # randomize in second split + # the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0] + random_sampler_ids = [3, 1, 2, 0] + + s1s1, s1s2, s1s3 = s1.split([1, 2, 1]) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]] + assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]] + assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + assert s2_output == [2] + + if __name__ == '__main__': test_unmappable_invalid_input() test_unmappable_split() + test_unmappable_randomize_deterministic() + test_unmappable_randomize_repeatable() + test_unmappable_get_dataset_size() + test_unmappable_multi_split() test_mappable_invalid_input() test_mappable_split_general() test_mappable_split_optimized() test_mappable_randomize_deterministic() test_mappable_randomize_repeatable() test_mappable_sharding() + test_mappable_get_dataset_size() + test_mappable_multi_split()
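
For reference, a minimal usage sketch (not part of the patch) of the newly exported SubsetSampler, the sampler-aware get_dataset_size(), and the use_sampler() type check added above. It assumes the same five-sample manifest used by the unit tests; the path and sizes are illustrative only.

# Illustrative sketch, assuming the test data layout from tests/ut/python/dataset.
import mindspore.dataset as ds

manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

# Read rows 1..3 of the 5-sample manifest (start_index=1, subset_size=3).
sampler = ds.SubsetSampler(1, 3)
data = ds.ManifestDataset(manifest_file, sampler=sampler)

# With this patch, the reported size is min(rows per shard, sampler size) == 3.
print(data.get_dataset_size())

for item in data.create_dict_iterator():
    print(item["image"].shape, item["label"])

# use_sampler() now rejects arguments that are not samplers.
data2 = ds.ManifestDataset(manifest_file)
try:
    data2.use_sampler(1)
except TypeError as e:
    print(e)  # "new_sampler is not an instance of a sampler."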