From d4c93575e7dc68f04f12369357553d9dc5d67dec Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Mon, 25 May 2020 17:19:08 -0400 Subject: [PATCH] fixed bug for split, RandomSampler and some other cleanup add another test case typo merge conflict another PR changed testing behavior, updated test cases in this commit added input check for use_sampler addressed code review comments fixed pylint, not related to my changes --- example/lstm_aclImdb/train.py | 2 +- .../source/sampler/random_sampler.cc | 11 +- .../datasetops/source/sampler/sampler.cc | 4 +- .../source/sampler/subset_sampler.cc | 6 +- .../source/sampler/subset_sampler.h | 2 +- mindspore/dataset/__init__.py | 2 +- mindspore/dataset/engine/datasets.py | 80 ++++-- mindspore/dataset/engine/samplers.py | 22 ++ tests/ut/python/dataset/test_sampler.py | 46 ++++ tests/ut/python/dataset/test_split.py | 244 +++++++++++++++++- 10 files changed, 381 insertions(+), 38 deletions(-) diff --git a/example/lstm_aclImdb/train.py b/example/lstm_aclImdb/train.py index 3d1c670f4e..08bea7c63d 100644 --- a/example/lstm_aclImdb/train.py +++ b/example/lstm_aclImdb/train.py @@ -71,7 +71,7 @@ if __name__ == '__main__': model = Model(network, loss, opt, {'acc': Accuracy()}) print("============== Starting Training ==============") - ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs) + ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index 96c83c3114..935ba4e2bd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { } Status RandomSampler::InitSampler() { - num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive."); rnd_.seed(seed_); if (replacement_ == false) { + num_samples_ = std::min(num_samples_, num_rows_); + shuffled_ids_.reserve(num_rows_); for (int64_t i = 0; i < num_rows_; i++) { shuffled_ids_.push_back(i); } std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); } else { + num_samples_ = std::min(num_samples_, user_num_samples_); dist = std::make_unique>(0, num_rows_ - 1); } + + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive."); + samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 695f364b7f..600d8c576b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { } // Handshake and init child first. - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); - } + RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); } CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc index 320bc601b9..0ae7a7d503 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size) : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {} Status SubsetSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n"); CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n"); num_samples_ = subset_size_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h index 70ee80b0a4..5e8774f673 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 0631ade36a..ceca188112 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ - WeightedRandomSampler, Sampler + WeightedRandomSampler, SubsetSampler, Sampler from .engine.serializer_deserializer import serialize, deserialize, show from .engine.graphdata import GraphData diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ae8f8dc243..f3703b3850 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -633,9 +633,9 @@ class Dataset: Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size of the original dataset. If after rounding, any size equals 0, an error will occur. All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. - randomize (bool): determines whether or not to split the data randomly. If true, the data - will be randomly split. Otherwise, each split will be created with consecutive rows - from the dataset. + randomize (bool, optional): determines whether or not to split the data randomly (default=True). + If true, the data will be randomly split. Otherwise, each split will be created with + consecutive rows from the dataset. Note: 1. Dataset cannot be sharded if split is going to be called. @@ -678,7 +678,8 @@ class Dataset: ds = copy.deepcopy(self) if randomize: # want to shuffle the same way every epoch before split - ds = ds.shuffle() + # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here + ds = ds.shuffle(10000) ds.reshuffle_each_epoch = False if rows_to_skip > 0: @@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset): >>> new_sampler = ds.DistributedSampler(10, 2) >>> data.use_sampler(new_sampler) """ + if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): + raise TypeError("new_sampler is not an instance of a sampler.") + self.sampler = self.sampler.child_sampler self.add_sampler(new_sampler) @@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset): def is_sharded(self): raise NotImplementedError("MappableDataset must implement is_sharded.") + def _get_sampler_dataset_size(self): + if self.sampler is not None: + return self.sampler.get_dataset_size() + + return None @check_split def split(self, sizes, randomize=True): @@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset): Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size of the original dataset. If after rounding, any size equals 0, an error will occur. All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. - randomize (bool): determines whether or not to split the data randomly. If true, the data - will be randomly split. Otherwise, each split will be created with consecutive rows - from the dataset. + randomize (bool, optional): determines whether or not to split the data randomly (default=True). + If true, the data will be randomly split. Otherwise, each split will be created with + consecutive rows from the dataset. Note: 1. Dataset should not be sharded if split is going to be called. Instead, create a @@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp): self.iterator = TupleIterator(self) - class RangeDataset(MappableDataset): """ A source dataset that reads and parses datasets stored on disk in a range. @@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset): else: num_samples = self.num_samples num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - return get_num_rows(num_rows, self.num_shards) + if rows_from_sampler is None: + return rows_per_shard + + return min(rows_from_sampler, rows_per_shard) def num_classes(self): """ @@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset): num_samples = self.num_samples num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples) + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return rows_per_shard - return get_num_rows(num_rows, self.num_shards) + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): if self.shuffle_level is None: @@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset): Return: Number, number of batches. """ - return self._dataset_size + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return self._dataset_size + + return min(rows_from_sampler, self._dataset_size) # manually set dataset_size as a temporary solution. def set_dataset_size(self, value): @@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset): class_indexing = self.class_indexing num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0] + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return rows_per_shard - return get_num_rows(num_rows, self.num_shards) + return min(rows_from_sampler, rows_per_shard) def num_classes(self): """ @@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset): num_samples = self.num_samples num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True) + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - return get_num_rows(num_rows, self.num_shards) + if rows_from_sampler is None: + return rows_per_shard + + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): if self.shuffle_level is None: @@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset): num_samples = self.num_samples num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False) + rows_per_shard = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return rows_per_shard - return get_num_rows(num_rows, self.num_shards) + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): if self.shuffle_level is None: @@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset): Return: Number, number of batches. """ - return num_samples + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return self.num_samples + + return min(rows_from_sampler, self.num_samples) def is_shuffled(self): return True @@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset): Return: Number, number of batches. """ - return self.num_samples + rows_from_sampler = self._get_sampler_dataset_size() + + if rows_from_sampler is None: + return self.num_samples + + return min(rows_from_sampler, self.num_samples) def get_class_indexing(self): """ diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 8bf223251a..265d20f389 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -114,6 +114,9 @@ class Sampler: return self.child_sampler.is_sharded() + def get_dataset_size(self): + return self._get_indices().size + class BuiltinSampler: """ @@ -146,6 +149,12 @@ class BuiltinSampler: def is_sharded(self): raise NotImplementedError("Sampler must implement is_sharded.") + def get_dataset_size(self): + if self.child_sampler is not None: + return self.child_sampler.get_dataset_size() + + return None + class DistributedSampler(BuiltinSampler): """ @@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler): return self.child_sampler.is_sharded() + def get_dataset_size(self): + return self.num_samples + class SequentialSampler(BuiltinSampler): """ @@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler): return self.child_sampler.is_sharded() + def get_dataset_size(self): + return self.subset_size + class SubsetRandomSampler(BuiltinSampler): """ @@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler): return cde.MindrecordSubsetRandomSampler(self.indices) + def get_dataset_size(self): + return len(indices) + + class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). @@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler): return False return self.child_sampler.is_sharded() + + def get_dataset_size(self): + return self.num_samples diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 01a5f8d228..04b0602152 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -164,6 +165,35 @@ def test_python_sampler(): assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] +def test_subset_sampler(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + + def test_config(num_samples, start_index, subset_size): + sampler = ds.SubsetSampler(start_index, subset_size) + d = ds.ManifestDataset(manifest_file, sampler=sampler) + + res = [] + for item in d.create_dict_iterator(): + res.append(map[(item["image"].shape[0], item["label"].item())]) + + return res + + with pytest.raises(RuntimeError) as info: + test_config(5, 0, 0) + assert "subset_size <= 0" in str(info.value) + + assert test_config(5, 0, 1) == [0] + assert test_config(5, 0, 2) == [0, 1] + assert test_config(5, 0, 3) == [0, 1, 2] + assert test_config(5, 0, 4) == [0, 1, 2, 3] + assert test_config(5, 0, 5) == [0, 1, 2, 3, 4] + assert test_config(5, 1, 1) == [1] + assert test_config(5, 2, 3) == [2, 3, 4] + assert test_config(5, 3, 2) == [3, 4] + assert test_config(5, 4, 1) == [4] + + def test_sampler_chain(): manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} @@ -190,10 +220,26 @@ def test_sampler_chain(): assert test_config(5, 3) == [3] assert test_config(5, 4) == [4] +def test_add_sampler_invalid_input(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + data1 = ds.ManifestDataset(manifest_file) + + with pytest.raises(TypeError) as info: + data1.use_sampler(1) + assert "not an instance of a sampler" in str(info.value) + + with pytest.raises(TypeError) as info: + data1.use_sampler("sampler") + assert "not an instance of a sampler" in str(info.value) + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) test_random_sampler_multi_iter(True) test_sampler_py_api() test_python_sampler() + test_subset_sampler() test_sampler_chain() + test_add_sampler_invalid_input() diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index fa28b49181..0b546e5f6f 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -23,7 +23,11 @@ from util import config_get_set_num_parallel_workers manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} -def split_with_invalid_inputs(d): +text_file_dataset_path = "../data/dataset/testTextFileDataset/*" +text_file_data = ["This is a text file.", "Another file.", "Be happy every day.", + "End of file.", "Good luck to everyone."] + +def split_with_invalid_inputs(d): with pytest.raises(ValueError) as info: s1, s2 = d.split([]) assert "sizes cannot be empty" in str(info.value) @@ -68,8 +72,8 @@ def split_with_invalid_inputs(d): s1, s2 = d.split([0.05, 0.95]) assert "percentage 0.05 is too small" in str(info.value) + def test_unmappable_invalid_input(): - text_file_dataset_path = "../data/dataset/testTextFileDataset/*" d = ds.TextFileDataset(text_file_dataset_path) split_with_invalid_inputs(d) @@ -78,11 +82,10 @@ def test_unmappable_invalid_input(): s1, s2 = d.split([4, 1]) assert "dataset should not be sharded before split" in str(info.value) + def test_unmappable_split(): - text_file_dataset_path = "../data/dataset/testTextFileDataset/*" - text_file_data = ["This is a text file.", "Another file.", "Be happy every day.", - "End of file.", "Good luck to everyone."] original_num_parallel_workers = config_get_set_num_parallel_workers(4) + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) s1, s2 = d.split([4, 1], randomize=False) @@ -124,6 +127,142 @@ def test_unmappable_split(): assert s1_output == text_file_data[0:2] assert s2_output == text_file_data[2:] + + # Restore configuration num_parallel_workers + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_unmappable_randomize_deterministic(): + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + + # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] + ds.config.set_seed(53) + + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([0.8, 0.2]) + + for _ in range(10): + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(item["text"].item().decode("utf8")) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + + # note no overlap + assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] + assert s2_output == [text_file_data[3]] + + # Restore configuration num_parallel_workers + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_unmappable_randomize_repeatable(): + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + + # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] + ds.config.set_seed(53) + + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([0.8, 0.2]) + + num_epochs = 5 + s1 = s1.repeat(num_epochs) + s2 = s2.repeat(num_epochs) + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(item["text"].item().decode("utf8")) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + + # note no overlap + assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs + assert s2_output == [text_file_data[3]] * num_epochs + + # Restore configuration num_parallel_workers + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_unmappable_get_dataset_size(): + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([0.8, 0.2]) + + assert d.get_dataset_size() == 5 + assert s1.get_dataset_size() == 4 + assert s2.get_dataset_size() == 1 + + +def test_unmappable_multi_split(): + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + + # the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3] + ds.config.set_seed(53) + + d = ds.TextFileDataset(text_file_dataset_path, shuffle=False) + s1, s2 = d.split([4, 1]) + + s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(item["text"].item().decode("utf8")) + assert s1_output == s1_correct_output + + # no randomize in second split + s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(item["text"].item().decode("utf8")) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(item["text"].item().decode("utf8")) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(item["text"].item().decode("utf8")) + + assert s1s1_output == [s1_correct_output[0]] + assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]] + assert s1s3_output == [s1_correct_output[3]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + assert s2_output == [text_file_data[3]] + + # randomize in second split + # the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0] + shuffled_ids = [2, 3, 1, 0] + + s1s1, s1s2, s1s3 = s1.split([1, 2, 1]) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(item["text"].item().decode("utf8")) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(item["text"].item().decode("utf8")) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(item["text"].item().decode("utf8")) + + assert s1s1_output == [s1_correct_output[shuffled_ids[0]]] + assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]] + assert s1s3_output == [s1_correct_output[shuffled_ids[3]]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(item["text"].item().decode("utf8")) + assert s2_output == [text_file_data[3]] + # Restore configuration num_parallel_workers ds.config.set_num_parallel_workers(original_num_parallel_workers) @@ -137,6 +276,7 @@ def test_mappable_invalid_input(): s1, s2 = d.split([4, 1]) assert "dataset should not be sharded before split" in str(info.value) + def test_mappable_split_general(): d = ds.ManifestDataset(manifest_file, shuffle=False) d = d.take(5) @@ -183,6 +323,7 @@ def test_mappable_split_general(): assert s1_output == [0, 1] assert s2_output == [2, 3, 4] + def test_mappable_split_optimized(): d = ds.ManifestDataset(manifest_file, shuffle=False) @@ -228,9 +369,9 @@ def test_mappable_split_optimized(): assert s1_output == [0, 1] assert s2_output == [2, 3, 4] + def test_mappable_randomize_deterministic(): - # set arbitrary seed for shard after split - # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] ds.config.set_seed(53) d = ds.ManifestDataset(manifest_file, shuffle=False) @@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic(): assert s1_output == [0, 1, 3, 4] assert s2_output == [2] + def test_mappable_randomize_repeatable(): - # set arbitrary seed for shard after split - # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] ds.config.set_seed(53) d = ds.ManifestDataset(manifest_file, shuffle=False) @@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable(): assert s1_output == [0, 1, 3, 4] * num_epochs assert s2_output == [2] * num_epochs + def test_mappable_sharding(): # set arbitrary seed for repeatability for shard after split - # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4] + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] ds.config.set_seed(53) num_epochs = 5 @@ -336,12 +478,94 @@ def test_mappable_sharding(): assert s2_output == [2] assert d2s2_output == [2] + +def test_mappable_get_dataset_size(): + d = ds.ManifestDataset(manifest_file, shuffle=False) + s1, s2 = d.split([4, 1]) + + assert d.get_dataset_size() == 5 + assert s1.get_dataset_size() == 4 + assert s2.get_dataset_size() == 1 + + +def test_mappable_multi_split(): + # the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2] + ds.config.set_seed(53) + + d = ds.ManifestDataset(manifest_file, shuffle=False) + s1, s2 = d.split([4, 1]) + + s1_correct_output = [0, 1, 3, 4] + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + assert s1_output == s1_correct_output + + # no randomize in second split + s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1s1_output == [s1_correct_output[0]] + assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]] + assert s1s3_output == [s1_correct_output[3]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + assert s2_output == [2] + + # randomize in second split + # the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0] + random_sampler_ids = [3, 1, 2, 0] + + s1s1, s1s2, s1s3 = s1.split([1, 2, 1]) + + s1s1_output = [] + for item in s1s1.create_dict_iterator(): + s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s2_output = [] + for item in s1s2.create_dict_iterator(): + s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s1s3_output = [] + for item in s1s3.create_dict_iterator(): + s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]] + assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]] + assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]] + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + assert s2_output == [2] + + if __name__ == '__main__': test_unmappable_invalid_input() test_unmappable_split() + test_unmappable_randomize_deterministic() + test_unmappable_randomize_repeatable() + test_unmappable_get_dataset_size() + test_unmappable_multi_split() test_mappable_invalid_input() test_mappable_split_general() test_mappable_split_optimized() test_mappable_randomize_deterministic() test_mappable_randomize_repeatable() test_mappable_sharding() + test_mappable_get_dataset_size() + test_mappable_multi_split()