From 4b106f184c0eb816cce51066b9c30507b1d6aef4 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Wed, 10 Feb 2021 12:49:56 -0500 Subject: [PATCH] Fixed an issue with SequentialSampler --- .../datasetops/source/sampler/sequential_sampler.cc | 8 +++----- tests/ut/python/dataset/test_sampler_chain.py | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index 3a9ee21c3a..70ff27dc05 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -113,13 +113,11 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) { int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; // For this sampler we need to take start_index into account. Because for example in the case we are given n rows // and start_index != 0 and num_samples >= n then we can't return all the n rows. - if (child_num_rows - (start_index_ - current_id_) <= 0) { + if (child_num_rows - start_index_ <= 0) { return 0; } - if (child_num_rows - (start_index_ - current_id_) < num_samples) - num_samples = child_num_rows - (start_index_ - current_id_) > num_samples - ? num_samples - : num_samples - (start_index_ - current_id_); + if (child_num_rows - start_index_ < num_samples) + num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_; return num_samples; } diff --git a/tests/ut/python/dataset/test_sampler_chain.py b/tests/ut/python/dataset/test_sampler_chain.py index 6b68dcb097..8b4471d75d 100644 --- a/tests/ut/python/dataset/test_sampler_chain.py +++ b/tests/ut/python/dataset/test_sampler_chain.py @@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain(): # Use 1 statement to add child sampler np_data = [1, 2, 3, 4] sampler = ds.SequentialSampler(start_index=1, num_samples=2) - sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) + sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) # Verify dataset size data1_size = data1.get_dataset_size() logger.info("dataset size is: {}".format(data1_size)) - assert data1_size == 4 + assert data1_size == 1 # Verify number of rows - assert sum([1 for _ in data1]) == 4 + assert sum([1 for _ in data1]) == 1 # Verify dataset contents res = []