From 96859f44b67d90e2f57ce71198230cb99a8f3b3b Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Thu, 11 Jun 2020 09:15:05 +0800
Subject: [PATCH] fix: MindDataset distribute shuffle bug

---
 mindspore/ccsrc/mindrecord/io/shard_reader.cc |  16 ++-
 tests/ut/python/dataset/test_minddataset.py   | 133 ++++++++++++++++++
 2 files changed, 146 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
index e5b18c1f9c..9d6ea969ea 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_reader.h"
 #include "common/utils.h"
 
@@ -1385,9 +1386,18 @@ void ShardReader::Reset() {
 
 void ShardReader::ShuffleTask() {
   for (const auto &op : operators_) {
-    if (block_reader_ || !std::dynamic_pointer_cast<ShardShuffle>(op)) continue;
-    if (SUCCESS != (*op)(tasks_)) {
-      MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
+    if (block_reader_) {
+      continue;
+    }
+
+    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+      if (SUCCESS != (*op)(tasks_)) {
+        MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
+      }
+    } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+      if (SUCCESS != op->PreExecute(tasks_)) {
+        MS_LOG(WARNING) << "Distribute reshuffle reader tasks failed.";
+      }
     }
   }
 }
diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py
index 991bdf71a1..986fc6b665 100644
--- a/tests/ut/python/dataset/test_minddataset.py
+++ b/tests/ut/python/dataset/test_minddataset.py
@@ -238,6 +238,139 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
     assert partitions(9) == 2
 
 
+def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    num_shards = 3
+    epoch1 = []
+    epoch2 = []
+    epoch3 = []
+
+    for partition_id in range(num_shards):
+        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                                  num_shards=num_shards, shard_id=partition_id)
+
+        data_set = data_set.repeat(3)
+
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            logger.info("-------------- partition : {} ------------------------".format(partition_id))
+            logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+            logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+            num_iter += 1
+            if num_iter <= 4:
+                epoch1.append(item["file_name"])  # save epoch 1 list
+            elif num_iter <= 8:
+                epoch2.append(item["file_name"])  # save epoch 2 list
+            else:
+                epoch3.append(item["file_name"])  # save epoch 3 list
+        assert num_iter == 12
+        assert len(epoch1) == 4
+        assert len(epoch2) == 4
+        assert len(epoch3) == 4
+        assert epoch1 not in (epoch2, epoch3)
+        assert epoch2 not in (epoch1, epoch3)
+        assert epoch3 not in (epoch1, epoch2)
+        epoch1 = []
+        epoch2 = []
+        epoch3 = []
+
+
+def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+
+    ds.config.set_seed(54321)
+    epoch1 = []
+    epoch2 = []
+    epoch3 = []
+
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    data_set = data_set.repeat(3)
+
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 10:
+            epoch1.append(item["file_name"])  # save epoch 1 list
+        elif num_iter <= 20:
+            epoch2.append(item["file_name"])  # save epoch 2 list
+        else:
+            epoch3.append(item["file_name"])  # save epoch 3 list
+    assert num_iter == 30
+    assert len(epoch1) == 10
+    assert len(epoch2) == 10
+    assert len(epoch3) == 10
+    assert epoch1 not in (epoch2, epoch3)
+    assert epoch2 not in (epoch1, epoch3)
+    assert epoch3 not in (epoch1, epoch2)
+
+    epoch1_new_dataset = []
+    epoch2_new_dataset = []
+    epoch3_new_dataset = []
+
+    data_set2 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    data_set2 = data_set2.repeat(3)
+
+    num_iter = 0
+    for item in data_set2.create_dict_iterator():
+        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 10:
+            epoch1_new_dataset.append(item["file_name"])  # save epoch 1 list
+        elif num_iter <= 20:
+            epoch2_new_dataset.append(item["file_name"])  # save epoch 2 list
+        else:
+            epoch3_new_dataset.append(item["file_name"])  # save epoch 3 list
+    assert num_iter == 30
+    assert len(epoch1_new_dataset) == 10
+    assert len(epoch2_new_dataset) == 10
+    assert len(epoch3_new_dataset) == 10
+    assert epoch1_new_dataset not in (epoch2_new_dataset, epoch3_new_dataset)
+    assert epoch2_new_dataset not in (epoch1_new_dataset, epoch3_new_dataset)
+    assert epoch3_new_dataset not in (epoch1_new_dataset, epoch2_new_dataset)
+
+    assert epoch1 == epoch1_new_dataset
+    assert epoch2 == epoch2_new_dataset
+    assert epoch3 == epoch3_new_dataset
+
+    ds.config.set_seed(12345)
+    epoch1_new_dataset2 = []
+    epoch2_new_dataset2 = []
+    epoch3_new_dataset2 = []
+
+    data_set3 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    data_set3 = data_set3.repeat(3)
+
+    num_iter = 0
+    for item in data_set3.create_dict_iterator():
+        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 10:
+            epoch1_new_dataset2.append(item["file_name"])  # save epoch 1 list
+        elif num_iter <= 20:
+            epoch2_new_dataset2.append(item["file_name"])  # save epoch 2 list
+        else:
+            epoch3_new_dataset2.append(item["file_name"])  # save epoch 3 list
+    assert num_iter == 30
+    assert len(epoch1_new_dataset2) == 10
+    assert len(epoch2_new_dataset2) == 10
+    assert len(epoch3_new_dataset2) == 10
+    assert epoch1_new_dataset2 not in (epoch2_new_dataset2, epoch3_new_dataset2)
+    assert epoch2_new_dataset2 not in (epoch1_new_dataset2, epoch3_new_dataset2)
+    assert epoch3_new_dataset2 not in (epoch1_new_dataset2, epoch2_new_dataset2)
+
+    assert epoch1 != epoch1_new_dataset2
+    assert epoch2 != epoch2_new_dataset2
+    assert epoch3 != epoch3_new_dataset2
+
+
 def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
     columns_list = ["data", "file_name", "label"]