From 5b4b5397510245e50d5d3dd1130fa6039dc8e23f Mon Sep 17 00:00:00 2001
From: jonyguo <guozhijian@huawei.com>
Date: Thu, 20 Aug 2020 22:32:12 +0800
Subject: [PATCH] fix: concat with none sample dataset

---
 mindspore/dataset/engine/datasets.py          | 21 ++++++----
 tests/ut/python/dataset/test_paddeddataset.py | 41 +++++++++++++++++++
 2 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 3fa390b5e1..26a27448e6 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -2310,6 +2310,7 @@ class ConcatDataset(DatasetOp):
 
     Raises:
         TypeError: If dataset is not an instance of Dataset.
+        ValueError: If there are no samples in one of the datasets.
     """
 
     def __init__(self, datasets):
@@ -2324,15 +2325,19 @@ class ConcatDataset(DatasetOp):
             data.parent.append(self)
 
         self.children_sizes_ = [c.get_dataset_size() for c in self.children]
-        """
-        _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
-        whether the data set is mappable. The second element of pair is length of the dataset
-        """
+        child_index = 0
+        for item in self.children_sizes_:
+            if item == 0:
+                raise ValueError("There is no samples in the %dth dataset. Please make sure there are "
+                                 "valid samples in the dataset" % child_index)
+            child_index += 1
+
+        # _children_flag_and_nums: A list of pair<int, int>. The first element of the pair is a flag
+        # indicating whether the dataset is mappable. The second element is the length of the dataset.
         self._children_flag_and_nums = []
-        """
-         _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
-        the valid position of the dataset corresponding to the subscript when sampling
-        """
+
+        # _children_start_end_index_: A list of pair<int, int>. The elements of the pair characterize
+        # the valid sampling positions of the dataset corresponding to the subscript when sampling.
         self._children_start_end_index_ = []
         for index, child in enumerate(self.children):
             tem_list = [-1, -1]
diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py
index 0c25b67b80..2d6efd4df6 100644
--- a/tests/ut/python/dataset/test_paddeddataset.py
+++ b/tests/ut/python/dataset/test_paddeddataset.py
@@ -1,4 +1,5 @@
 from io import BytesIO
+import copy
 import os
 import numpy as np
 import pytest
@@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file):
         result_list.append(tem_list)
     assert result_list == verify_list
 
+def test_clue_padded_and_skip_with_0_samples():
+    """
+    Test that concatenating a dataset which has zero samples raises ValueError
+    """
+    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
+
+    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
+    count = 0
+    for _ in data.create_dict_iterator():
+        count += 1
+    assert count == 3
+
+    data_copy1 = copy.deepcopy(data)
+
+    sample = {"label": np.array(1, np.string_),
+              "sentence1": np.array(1, np.string_),
+              "sentence2": np.array(1, np.string_)}
+    samples = [sample]
+    padded_ds = ds.PaddedDataset(samples)
+    dataset = data + padded_ds
+    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
+    dataset.use_sampler(testsampler)
+    assert dataset.get_dataset_size() == 2
+    count = 0
+    for data in dataset.create_dict_iterator():
+        count += 1
+    assert count == 2
+
+    dataset = dataset.skip(count=2)    # dataset now has no samples
+    count = 0
+    for data in dataset.create_dict_iterator():
+        count += 1
+    assert count == 0
+
+    with pytest.raises(ValueError, match="There is no samples in the "):
+        dataset = dataset.concat(data_copy1)
+        count = 0
+        for data in dataset.create_dict_iterator():
+            count += 1
+        assert count == 2
 
 if __name__ == '__main__':
     test_TFRecord_Padded()