diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index b1f6e37eeb..18fa530dc0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -173,7 +173,8 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { child_num_rows = child_[0]->CalculateNumSamples(num_rows); } int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; - return std::ceil(num_samples * 1.0 / num_devices_); + int64_t num_per_shard = std::ceil(num_rows * 1.0 / num_devices_); + return std::min(num_samples, num_per_shard); } void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index 1843b31656..0006a64444 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -811,6 +811,68 @@ TEST_F(MindDataTestPipeline, TestPipelineGetDatasetSize) { EXPECT_EQ(ds->GetDatasetSize(), 10); } +TEST_F(MindDataTestPipeline, TestDistributedGetDatasetSize1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedGetDatasetSize1."; + // Test get dataset size in distributed scenario when num_per_shard is more than num_samples + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, DistributedSampler(4, 0, false, 10)); + EXPECT_NE(ds, nullptr); + + // num_per_shard is equal to 44/4 = 11 which is more than num_samples = 10, so the output is 10 + EXPECT_EQ(ds->GetDatasetSize(), 10); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the 
Execution Tree and launch it. + std::shared_ptr<Iterator> iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // iterate over the dataset and get each row + std::unordered_map<std::string, std::shared_ptr<Tensor>> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + iter->GetNextRow(&row); + } + + // The value of i should be equal to the result of get dataset size + EXPECT_EQ(i, 10); +} + +TEST_F(MindDataTestPipeline, TestDistributedGetDatasetSize2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedGetDatasetSize2."; + // Test get dataset size in distributed scenario when num_per_shard is less than num_samples + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, DistributedSampler(4, 0, false, 15)); + EXPECT_NE(ds, nullptr); + + // num_per_shard is equal to 44/4 = 11 which is less than num_samples = 15, so the output is 11 + EXPECT_EQ(ds->GetDatasetSize(), 11); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr<Iterator> iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // iterate over the dataset and get each row + std::unordered_map<std::string, std::shared_ptr<Tensor>> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + iter->GetNextRow(&row); + } + + // The value of i should be equal to the result of get dataset size + EXPECT_EQ(i, 11); +} + TEST_F(MindDataTestPipeline, TestProjectMap) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectMap."; diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc index 01a3092f22..6836497189 100644 --- a/tests/ut/cpp/dataset/c_api_samplers_test.cc +++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc @@ -87,7 +87,7 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) { std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6); EXPECT_NE(sampl, nullptr); std::shared_ptr<SamplerRT> sampler_rt = sampl->Build(); - EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 3); + EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); sampl = PKSampler(3, false); EXPECT_NE(sampl, nullptr); diff --git a/tests/ut/python/dataset/test_datasets_get_dataset_size.py b/tests/ut/python/dataset/test_datasets_get_dataset_size.py index f0d3503cfb..415e916de2 100644 --- a/tests/ut/python/dataset/test_datasets_get_dataset_size.py +++ b/tests/ut/python/dataset/test_datasets_get_dataset_size.py @@ -75,6 +75,7 @@ def test_imagenet_tf_file_dataset_size(): count += 1 assert ds_shard_4_0.get_dataset_size() == count + def test_mnist_dataset_size(): ds_total = ds.MnistDataset(MNIST_DATA_DIR) assert ds_total.get_dataset_size() == 10000 @@ -247,6 +248,26 @@ def test_pipeline_get_dataset_size(): assert tf2.concat(tf1).get_dataset_size() == 24 +def test_distributed_get_dataset_size(): + # Test get dataset size when num_samples is less than num_per_shard (10000/4 = 2500) + dataset1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=2000, num_shards=4, shard_id=0) + assert dataset1.get_dataset_size() == 2000 + 
count1 = 0 + for _ in dataset1.create_dict_iterator(): + count1 += 1 + assert count1 == 2000 + + # Test get dataset size when num_samples is more than num_per_shard (10000/4 = 2500) + dataset2 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3000, num_shards=4, shard_id=0) + assert dataset2.get_dataset_size() == 2500 + + count2 = 0 + for _ in dataset2.create_dict_iterator(): + count2 += 1 + assert count2 == 2500 + + if __name__ == '__main__': test_imagenet_rawdata_dataset_size() test_imagenet_tf_file_dataset_size() @@ -263,3 +284,4 @@ if __name__ == '__main__': test_text_file_dataset_size() test_padded_dataset_size() test_pipeline_get_dataset_size() + test_distributed_get_dataset_size()