From 930e85ed5c4b05ac0a5e344079ececddc68863a0 Mon Sep 17 00:00:00 2001
From: liyong
Date: Tue, 15 Sep 2020 22:41:55 +0800
Subject: [PATCH] fix get_dataset_size in distributedSampler & num_samples

---
 .../meta/shard_distributed_sample.cc          |  6 ++-
 tests/ut/python/dataset/test_minddataset.py   | 38 ++++++++++++++-----
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
index 85300c6d67..7ffea62098 100644
--- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
@@ -37,11 +37,13 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, boo
 
 int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (no_of_padded_samples_ <= 0) {
+    int64_t res = 0;
     if (dataset_size % denominator_ == 0) {
-      return dataset_size / denominator_ * numerator_;
+      res = dataset_size / denominator_ * numerator_;
     } else {
-      return dataset_size / denominator_ * numerator_ + 1;
+      res = dataset_size / denominator_ * numerator_ + 1;
     }
+    return no_of_samples_ == 0 ? res : std::min(static_cast<int64_t>(no_of_samples_), res);
   } else {
     auto padded_size = dataset_size + no_of_padded_samples_;
     if (padded_size % denominator_ == 0) {
diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py
index a08fb40956..d201d4d4b7 100644
--- a/tests/ut/python/dataset/test_minddataset.py
+++ b/tests/ut/python/dataset/test_minddataset.py
@@ -278,6 +278,8 @@ def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
                                       shard_id=partition_id, num_samples=1)
+
+            assert data_set.get_dataset_size() == 1
             num_iter = 0
             for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
@@ -301,6 +303,8 @@ def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
                                       shard_id=partition_id, num_samples=2)
+
+            assert data_set.get_dataset_size() == 2
             num_iter = 0
             for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
@@ -319,11 +323,13 @@ def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
 
-    def partitions(num_shards):
+    def partitions(num_shards, expect):
         for partition_id in range(num_shards):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
                                       shard_id=partition_id, num_samples=3)
+
+            assert data_set.get_dataset_size() == expect
             num_iter = 0
             for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
@@ -332,10 +338,25 @@ def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
                 num_iter += 1
         return num_iter
 
-    assert partitions(4) == 3
-    assert partitions(5) == 2
-    assert partitions(9) == 2
+    assert partitions(4, 3) == 3
+    assert partitions(5, 2) == 2
+    assert partitions(9, 2) == 2
+
+def test_cv_minddataset_partition_num_samples_3(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, num_shards=1, shard_id=0, num_samples=5)
+
+    assert data_set.get_dataset_size() == 5
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+        num_iter += 1
+
+    assert num_iter == 5
 
 def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
@@ -841,13 +862,10 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
 
     # define map operations
     decode_op = vision.Decode()
-    resize_op = vision.Resize(
-        (resize_height, resize_width), ds.transforms.vision.Inter.LINEAR)
+    resize_op = vision.Resize((resize_height, resize_width))
 
-    data_set = data_set.map(
-        input_columns=["data"], operations=decode_op, num_parallel_workers=4)
-    data_set = data_set.map(
-        input_columns=["data"], operations=resize_op, num_parallel_workers=4)
+    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
+    data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)
     data_set = data_set.batch(2)
     assert data_set.get_dataset_size() == 5