From 6df1a43ec27d20989fd4857453f0b205c2157ac1 Mon Sep 17 00:00:00 2001 From: jonyguo Date: Tue, 25 Aug 2020 22:43:21 +0800 Subject: [PATCH] fix: padded dataset with non div & repeat --- .../source/sampler/distributed_sampler.cc | 3 +++ tests/ut/python/dataset/test_paddeddataset.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index bff155a7c8..407cb0ac22 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -75,6 +75,9 @@ Status DistributedSampler::GetNextSample(std::unique_ptr *out_buffer RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); } else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + if (!samples_per_buffer_) { + non_empty_ = false; + } } else if (!samples_per_buffer_ && !non_empty_) { // If the buffer is empty, we add samples with subscript 0 in the current dataset. // This step is to make up for the solution that the code default buffer is not empty before. diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index 2d6efd4df6..4fc4eea5a8 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples(): count += 1 assert count == 2 +def test_celeba_padded(): + data = ds.CelebADataset("../data/dataset/testCelebAData/") + + padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}] + padded_ds = ds.PaddedDataset(padded_samples) + data = data + padded_ds + dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None) + data.use_sampler(dis_sampler) + data = data.repeat(2) + + count = 0 + for _ in data.create_dict_iterator(): + count = count + 1 + assert count == 2 + if __name__ == '__main__': test_TFRecord_Padded() test_GeneratorDataSet_Padded()