From 5b4b5397510245e50d5d3dd1130fa6039dc8e23f Mon Sep 17 00:00:00 2001 From: jonyguo <guozhijian@huawei.com> Date: Thu, 20 Aug 2020 22:32:12 +0800 Subject: [PATCH] fix: concat with none sample dataset --- mindspore/dataset/engine/datasets.py | 21 ++++++---- tests/ut/python/dataset/test_paddeddataset.py | 41 +++++++++++++++++++ 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 3fa390b5e1..26a27448e6 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2310,6 +2310,7 @@ class ConcatDataset(DatasetOp): Raises: TypeError: If dataset is not an instance of Dataset. + ValueError: If there is no samples in the one of the datasets. """ def __init__(self, datasets): @@ -2324,15 +2325,19 @@ class ConcatDataset(DatasetOp): data.parent.append(self) self.children_sizes_ = [c.get_dataset_size() for c in self.children] - """ - _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes - whether the data set is mappable. The second element of pair is length of the dataset - """ + child_index = 0 + for item in self.children_sizes_: + if item == 0: + raise ValueError("There is no samples in the %dth dataset. Please make sure there are " + "valid samples in the dataset" % child_index) + child_index += 1 + + # _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes + # whether the data set is mappable. The second element of pair is length of the dataset self._children_flag_and_nums = [] - """ - _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize - the valid position of the dataset corresponding to the subscript when sampling - """ + + # _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize + # the valid position of the dataset corresponding to the subscript when sampling self._children_start_end_index_ = [] for index, child in enumerate(self.children): tem_list = [-1, -1] diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index 0c25b67b80..2d6efd4df6 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -1,4 +1,5 @@ from io import BytesIO +import copy import os import numpy as np import pytest @@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file): result_list.append(tem_list) assert result_list == verify_list +def test_clue_padded_and_skip_with_0_samples(): + """ + Test num_samples param of CLUE dataset + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train') + count = 0 + for _ in data.create_dict_iterator(): + count += 1 + assert count == 3 + + data_copy1 = copy.deepcopy(data) + + sample = {"label": np.array(1, np.string_), + "sentence1": np.array(1, np.string_), + "sentence2": np.array(1, np.string_)} + samples = [sample] + padded_ds = ds.PaddedDataset(samples) + dataset = data + padded_ds + testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None) + dataset.use_sampler(testsampler) + assert dataset.get_dataset_size() == 2 + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + assert count == 2 + + dataset = dataset.skip(count=2) # dataset2 has none samples + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + assert count == 0 + + with pytest.raises(ValueError, match="There is no samples in the "): + dataset = dataset.concat(data_copy1) + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + assert count == 2 if __name__ == '__main__': test_TFRecord_Padded()