diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
index ecff8ae5d3..ade1d496f5 100644
--- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
@@ -35,33 +35,33 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c
   if (per_ > kEpsilon && per_ <= 1.0f) {
     return dataset_size * per_;
   }
-  return no_of_samples_;
+  return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
 }
 
 MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
-  int total_no = static_cast<int>(tasks.Size());
-  int taking;
+  int64_t total_no = static_cast<int64_t>(tasks.Size());
+  int64_t taking;
   if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
     taking = total_no;
   } else if (per_ > kEpsilon && per_ <= 1.0f) {
     taking = total_no * per_;
   } else {
-    taking = no_of_samples_;
+    taking = std::min(static_cast<int64_t>(no_of_samples_), total_no);
   }
 
   if (tasks.permutation_.empty()) {
     ShardTask new_tasks;
-    total_no = static_cast<int>(tasks.Size());
-    for (int i = offset_; i < taking + offset_; ++i) {
+    total_no = static_cast<int64_t>(tasks.Size());
+    for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
     }
     std::swap(tasks, new_tasks);
   } else {  // shuffled
     ShardTask new_tasks;
-    if (taking > static_cast<int>(tasks.permutation_.size())) {
+    if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
       return FAILED;
     }
-    total_no = static_cast<int>(tasks.permutation_.size());
+    total_no = static_cast<int64_t>(tasks.permutation_.size());
     for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
     }
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
index 7743cabea3..3dcf9c9526 100644
--- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
@@ -39,7 +39,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (replacement_) {
     return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
   }
-  return dataset_size;
+  return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
 }
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
@@ -67,6 +67,14 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       std::swap(tasks, new_tasks);
     } else {
       std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+      auto total_no = static_cast<int64_t>(tasks.Size());
+      if (no_of_samples_ > 0 && no_of_samples_ < total_no) {
+        ShardTask new_tasks;
+        for (size_t i = 0; i < no_of_samples_; ++i) {
+          new_tasks.InsertTask(tasks.GetTaskByID(i));
+        }
+        std::swap(tasks, new_tasks);
+      }
     }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
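Taken together, the two C++ changes above apply one rule to the MindRecord samplers: a requested sample count of 0 means "take the whole dataset", and a positive count is clamped to the dataset size unless sampling with replacement (which may legitimately repeat rows). Ignoring the percentage-based (per_) mode, a minimal Python sketch of that rule follows; the function name and parameters are illustrative, not the C++ identifiers:

    def effective_num_samples(num_samples, dataset_size, replacement=False):
        # num_samples == 0 is the "not set" sentinel: take every row.
        if num_samples == 0:
            return dataset_size
        # With replacement, more rows than the dataset holds may be drawn,
        # so the request is honoured as-is.
        if replacement:
            return num_samples
        # Without replacement the output can never exceed the dataset.
        return min(num_samples, dataset_size)

    assert effective_num_samples(0, 10) == 10                     # whole dataset
    assert effective_num_samples(20, 10) == 10                    # clamped
    assert effective_num_samples(20, 10, replacement=True) == 20  # not clamped

This mirrors the new GetNumSamples in both shard_sequential_sample.cc and shard_shuffle.cc; the added block in ShardShuffle::Execute applies the same cap after shuffling by keeping only the first no_of_samples_ tasks.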
diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py
index 3c1503b030..3aa8eb4859 100644
--- a/mindspore/dataset/transforms/c_transforms.py
+++ b/mindspore/dataset/transforms/c_transforms.py
@@ -311,7 +311,7 @@ class Unique(cde.UniqueOp):
         Call batch op before calling this function.
 
     Examples:
-        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
+        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
         >>>
         >>> # Data before
         >>> # | x |
diff --git a/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc b/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc
index 47923b6bc1..f5c8578623 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc
@@ -208,7 +208,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
   std::vector<std::string> file_list = {file_path1};
 
-  // Check sequential sampler, output number is 10, with duplicate samples(a little weird, wait to fix)
+  // Check sequential sampler, output number is 5
   std::shared_ptr<Dataset> ds1 = MindData(file_list, {}, SequentialSampler(0, 10));
   EXPECT_NE(ds1, nullptr);
 
@@ -229,7 +229,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   EXPECT_NE(ds5, nullptr);
 
   std::vector<std::shared_ptr<Dataset>> ds = {ds1, ds2, ds3, ds4, ds5};
-  std::vector<int32_t> expected_samples = {10, 5, 2, 3, 3};
+  std::vector<int32_t> expected_samples = {5, 5, 2, 3, 3};
 
   for (int32_t i = 0; i < ds.size(); i++) {
     // Create an iterator over the result of the above dataset
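The revised expectations follow directly from the clamp: the test file imagenet.mindrecord0 evidently holds 5 rows, so SequentialSampler(0, 10) now returns each row once (5 samples) where it previously wrapped around and emitted 10 with duplicates. A plain-Python sketch of the old versus new index sequence, assuming a 5-row dataset as in the test:

    dataset_size, offset, requested = 5, 0, 10

    # Old behaviour: indices wrapped modulo the dataset size, duplicating rows.
    old_indices = [(offset + i) % dataset_size for i in range(requested)]
    assert old_indices == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]

    # New behaviour: the request is clamped first, so each row appears once.
    new_indices = [(offset + i) % dataset_size
                   for i in range(min(requested, dataset_size))]
    assert new_indices == [0, 1, 2, 3, 4]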
diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py
index d096b3fde0..40183592e7 100644
--- a/tests/ut/python/dataset/test_minddataset_sampler.py
+++ b/tests/ut/python/dataset/test_minddataset_sampler.py
@@ -412,6 +412,46 @@ def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 5
 
+def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=2)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 2
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 2
+
+def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 10
+
 
 def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
@@ -437,7 +477,7 @@ def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     assert num_iter == 4
 
 
-def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+def test_cv_minddataset_sequential_sampler_offset(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
@@ -461,6 +501,30 @@ def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 10
 
+def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(2, 20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    dataset_size = data_set.get_dataset_size()
+    assert dataset_size == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 10
+
 
 def test_cv_minddataset_split_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
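The new Python tests pin down the user-visible contract: with the fix, get_dataset_size() and the number of iterated rows both agree with min(num_samples, dataset rows) for RandomSampler(replacement=False) and SequentialSampler alike. In the exceed-size test, SequentialSampler(2, 20) over the 10-row file reports a size of 10 and reads rows starting at offset 2, wrapping past the end; a small sketch of the index order its assertion encodes (values taken from the test):

    dataset_size, offset, requested = 10, 2, 20

    taking = min(requested, dataset_size)  # request clamped to 10
    indices = [(offset + i) % dataset_size for i in range(taking)]
    assert indices == [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]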