You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							343 lines
						
					
					
						
							12 KiB
						
					
					
				
			
		
		
	
	
							343 lines
						
					
					
						
							12 KiB
						
					
					
				| # Copyright 2020 Huawei Technologies Co., Ltd
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| # http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| # ==============================================================================
 | |
| import pytest
 | |
| import mindspore.dataset as ds
 | |
| 
 | |
# test5trainimgs.json references 5 images; the label of each image is [0, 0, 0, 1, 1].
# Each image can be uniquely identified via the lookup table below, keyed by
# (un-decoded image size, label) -> unique image id.
# NOTE(review): an earlier comment listed un-decoded shapes
# [83554, 54214, 65512, 54214, 64631], which disagrees with two of the keys
# actually used here (172876, 173673) — the map below is what the tests assert
# against, so treat it as authoritative; confirm against the fixture data.
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
| 
 | |
def split_with_invalid_inputs(d):
    """Exercise Dataset.split() with malformed size lists and verify the errors.

    Args:
        d: a dataset of exactly 5 rows whose split() behavior is under test.
    """
    # Each case: (sizes argument, expected exception type, expected message fragment).
    bad_cases = [
        ([], ValueError, "sizes cannot be empty"),
        ([5, 0.6], ValueError, "sizes should be list of int or list of float"),
        ([-1, 6], ValueError, "there should be no negative numbers"),
        ([3, 1], RuntimeError, "sum of split sizes 4 is not equal to dataset size 5"),
        ([5, 1], RuntimeError, "sum of split sizes 6 is not equal to dataset size 5"),
        ([0.15, 0.15, 0.15, 0.15, 0.15, 0.25], RuntimeError,
         "sum of calculated split sizes 6 is not equal to dataset size 5"),
        ([-0.5, 0.5], ValueError, "there should be no numbers outside the range [0, 1]"),
        ([1.5, 0.5], ValueError, "there should be no numbers outside the range [0, 1]"),
        ([0.5, 0.6], ValueError, "percentages do not sum up to 1"),
        ([0.3, 0.6], ValueError, "percentages do not sum up to 1"),
        ([0.05, 0.95], RuntimeError, "percentage 0.05 is too small"),
    ]
    for sizes, err_type, fragment in bad_cases:
        with pytest.raises(err_type) as info:
            _ = d.split(sizes)
        assert fragment in str(info.value)
 | |
| 
 | |
def test_unmappable_invalid_input():
    """Invalid split() arguments on a non-mappable (TextFile) dataset raise errors."""
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"

    # All generic bad-size cases apply to the un-sharded dataset.
    split_with_invalid_inputs(ds.TextFileDataset(text_file_dataset_path))

    # Splitting a dataset that is already sharded is rejected outright.
    sharded = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _ = sharded.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)
 | |
| 
 | |
def test_unmappable_split():
    """Non-randomized split of a TextFile dataset, by absolute rows and by percentages."""
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    expected = ["This is a text file.", "Another file.", "Be happy every day.",
                "End of file.", "Good luck to everyone."]

    def lines_of(split):
        # Decode every "text" entry produced by the given split, in order.
        return [row["text"].item().decode("utf8") for row in split.create_dict_iterator()]

    ds.config.set_num_parallel_workers(4)
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)
    assert lines_of(s1) == expected[0:4]
    assert lines_of(s2) == expected[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)
    assert lines_of(s1) == expected[0:4]
    assert lines_of(s2) == expected[4:]

    # fuzzy percentages: 0.33 of 5 rows resolves to a 2/3 split
    s1, s2 = d.split([0.33, 0.67], randomize=False)
    assert lines_of(s1) == expected[0:2]
    assert lines_of(s2) == expected[2:]
 | |
| 
 | |
def test_mappable_invalid_input():
    """Invalid split() arguments on a mappable (Manifest) dataset raise errors."""
    # All generic bad-size cases apply to the un-sharded dataset.
    split_with_invalid_inputs(ds.ManifestDataset(manifest_file))

    # Splitting a dataset that is already sharded is rejected outright.
    sharded = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _ = sharded.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)
 | |
| 
 | |
def test_mappable_split_general():
    """Non-randomized split of a mappable dataset (take() before split forces the general path)."""

    def ids_of(split):
        # Map each row back to its unique image id via the (size, label) lookup table.
        return [manifest_map[(row["image"].shape[0], row["label"].item())]
                for row in split.create_dict_iterator()]

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)
    assert ids_of(s1) == [0, 1, 2, 3]
    assert ids_of(s2) == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)
    assert ids_of(s1) == [0, 1, 2, 3]
    assert ids_of(s2) == [4]

    # fuzzy percentages: 0.33 of 5 rows resolves to a 2/3 split
    s1, s2 = d.split([0.33, 0.67], randomize=False)
    assert ids_of(s1) == [0, 1]
    assert ids_of(s2) == [2, 3, 4]
 | |
| 
 | |
def test_mappable_split_optimized():
    """Non-randomized split applied directly to the source dataset (no take())."""

    def ids_of(split):
        # Map each row back to its unique image id via the (size, label) lookup table.
        return [manifest_map[(row["image"].shape[0], row["label"].item())]
                for row in split.create_dict_iterator()]

    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)
    assert ids_of(s1) == [0, 1, 2, 3]
    assert ids_of(s2) == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)
    assert ids_of(s1) == [0, 1, 2, 3]
    assert ids_of(s2) == [4]

    # fuzzy percentages: 0.33 of 5 rows resolves to a 2/3 split
    s1, s2 = d.split([0.33, 0.67], randomize=False)
    assert ids_of(s1) == [0, 1]
    assert ids_of(s2) == [2, 3, 4]
 | |
| 
 | |
def test_mappable_randomize_deterministic():
    """A randomized split yields the same rows on every re-iteration under a fixed seed."""
    # Arbitrary seed; with seed 53 the first split holds image ids [0, 1, 3, 4].
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_ids = [manifest_map[(row["image"].shape[0], row["label"].item())]
                  for row in s1.create_dict_iterator()]
        s2_ids = [manifest_map[(row["image"].shape[0], row["label"].item())]
                  for row in s2.create_dict_iterator()]

        # splits are stable across iterations and share no rows
        assert s1_ids == [0, 1, 3, 4]
        assert s2_ids == [2]
 | |
| 
 | |
def test_mappable_randomize_repeatable():
    """A randomized split produces the same row selection in every repeat() epoch."""
    # Arbitrary seed; with seed 53 the first split holds image ids [0, 1, 3, 4].
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_ids = [manifest_map[(row["image"].shape[0], row["label"].item())]
              for row in s1.create_dict_iterator()]
    s2_ids = [manifest_map[(row["image"].shape[0], row["label"].item())]
              for row in s2.create_dict_iterator()]

    # every epoch repeats the same selection; splits share no rows
    assert s1_ids == [0, 1, 3, 4] * num_epochs
    assert s2_ids == [2] * num_epochs
 | |
| 
 | |
def test_mappable_sharding():
    """Sharding applied after a randomized split distributes rows across shards without overlap."""
    # Arbitrary seed for repeatability; with seed 53 the first split holds ids [0, 1, 3, 4].
    ds.config.set_seed(53)

    num_epochs = 5
    first_split_num_rows = 4

    def ids_of(split):
        # Map each row back to its unique image id via the (size, label) lookup table.
        return [manifest_map[(row["image"].shape[0], row["label"].item())]
                for row in split.create_dict_iterator()]

    # instance reading shard 0
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])
    s1.use_sampler(ds.DistributedSampler(2, 0))
    s1 = s1.repeat(num_epochs)

    # second dataset simulates another training instance, reading shard 1
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])
    d2s1.use_sampler(ds.DistributedSampler(2, 1))
    d2s1 = d2s1.repeat(num_epochs)

    shard0_ids = ids_of(s1)
    shard1_ids = ids_of(d2s1)

    rows_per_shard_per_epoch = 2
    assert len(shard0_ids) == rows_per_shard_per_epoch * num_epochs
    assert len(shard1_ids) == rows_per_shard_per_epoch * num_epochs

    # verify each epoch that
    #   1. shards contain no common elements
    #   2. the data was split the same way, and that the union of shards equals the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for epoch in range(num_epochs):
        start = epoch * rows_per_shard_per_epoch
        stop = start + rows_per_shard_per_epoch
        combined_data = shard0_ids[start:stop] + shard1_ids[start:stop]
        assert sorted(combined_data) == correct_sorted_split_result

    # the second (un-sharded) split is identical in both instances
    assert ids_of(s2) == [2]
    assert ids_of(d2s2) == [2]
 | |
| 
 | |
if __name__ == '__main__':
    # Run the whole suite directly (outside of pytest collection), in order.
    for case in (test_unmappable_invalid_input,
                 test_unmappable_split,
                 test_mappable_invalid_input,
                 test_mappable_split_general,
                 test_mappable_split_optimized,
                 test_mappable_randomize_deterministic,
                 test_mappable_randomize_repeatable,
                 test_mappable_sharding):
        case()
 |