|
|
|
@@ -1,4 +1,5 @@
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
import copy
|
|
|
|
|
import os
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
@@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file):
|
|
|
|
|
result_list.append(tem_list)
|
|
|
|
|
assert result_list == verify_list
|
|
|
|
|
|
|
|
|
|
def test_clue_padded_and_skip_with_0_samples():
    """
    Test a CLUE dataset concatenated with a PaddedDataset: verify sharding,
    skipping down to zero samples, and that concatenating onto an exhausted
    pipeline raises ValueError.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    # Baseline: the raw CLUE train split yields 3 rows.
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
    count = 0
    for _ in data.create_dict_iterator():
        count += 1
    assert count == 3

    # Keep an untouched copy for the final concat attempt below.
    data_copy1 = copy.deepcopy(data)

    # One padding row with the same column names as the CLUE schema.
    # np.bytes_ is the forward-compatible spelling of np.string_
    # (np.string_ was deprecated in NumPy 1.20 and removed in 2.0).
    sample = {"label": np.array(1, np.bytes_),
              "sentence1": np.array(1, np.bytes_),
              "sentence2": np.array(1, np.bytes_)}
    samples = [sample]
    padded_ds = ds.PaddedDataset(samples)
    dataset = data + padded_ds

    # Shard the 4-row concat into 2 shards without shuffling;
    # shard 1 should hold exactly 2 rows.
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    dataset.use_sampler(testsampler)
    assert dataset.get_dataset_size() == 2
    count = 0
    # NOTE: loop variable renamed from `data` to `_` — the original shadowed
    # and clobbered the outer `data` dataset handle; the value is unused.
    for _ in dataset.create_dict_iterator():
        count += 1
    assert count == 2

    # Skipping both rows leaves an empty pipeline.
    dataset = dataset.skip(count=2)  # dataset2 has none samples
    count = 0
    for _ in dataset.create_dict_iterator():
        count += 1
    assert count == 0

    # Concatenating onto an exhausted pipeline must be rejected; the body
    # after concat() is unreachable when the expected ValueError fires.
    with pytest.raises(ValueError, match="There is no samples in the "):
        dataset = dataset.concat(data_copy1)
        count = 0
        for _ in dataset.create_dict_iterator():
            count += 1
        assert count == 2
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual entry point: runs a single padded-dataset test when this file is
    # executed directly; under pytest all test_* functions are discovered.
    # test_TFRecord_Padded is defined elsewhere in this file (outside this view).
    test_TFRecord_Padded()
|
|
|
|
|