fix: concat with none sample dataset

pull/4868/head
jonyguo 5 years ago
parent b69b1ca8a8
commit 5b4b539751

@ -2310,6 +2310,7 @@ class ConcatDataset(DatasetOp):
Raises:
TypeError: If dataset is not an instance of Dataset.
ValueError: If there is no samples in the one of the datasets.
"""
def __init__(self, datasets):
@ -2324,15 +2325,19 @@ class ConcatDataset(DatasetOp):
data.parent.append(self)
self.children_sizes_ = [c.get_dataset_size() for c in self.children]
"""
_children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
whether the data set is mappable. The second element of pair is length of the dataset
"""
child_index = 0
for item in self.children_sizes_:
if item == 0:
raise ValueError("There is no samples in the %dth dataset. Please make sure there are "
"valid samples in the dataset" % child_index)
child_index += 1
# _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
# whether the data set is mappable. The second element of pair is length of the dataset
self._children_flag_and_nums = []
"""
_children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
the valid position of the dataset corresponding to the subscript when sampling
"""
# _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
# the valid position of the dataset corresponding to the subscript when sampling
self._children_start_end_index_ = []
for index, child in enumerate(self.children):
tem_list = [-1, -1]

@ -1,4 +1,5 @@
from io import BytesIO
import copy
import os
import numpy as np
import pytest
@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file):
result_list.append(tem_list)
assert result_list == verify_list
def test_clue_padded_and_skip_with_0_samples():
"""
Test num_samples param of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 3
data_copy1 = copy.deepcopy(data)
sample = {"label": np.array(1, np.string_),
"sentence1": np.array(1, np.string_),
"sentence2": np.array(1, np.string_)}
samples = [sample]
padded_ds = ds.PaddedDataset(samples)
dataset = data + padded_ds
testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
dataset.use_sampler(testsampler)
assert dataset.get_dataset_size() == 2
count = 0
for data in dataset.create_dict_iterator():
count += 1
assert count == 2
dataset = dataset.skip(count=2) # dataset2 has none samples
count = 0
for data in dataset.create_dict_iterator():
count += 1
assert count == 0
with pytest.raises(ValueError, match="There is no samples in the "):
dataset = dataset.concat(data_copy1)
count = 0
for data in dataset.create_dict_iterator():
count += 1
assert count == 2
if __name__ == '__main__':
test_TFRecord_Padded()

Loading…
Cancel
Save