@@ -33,7 +33,7 @@ import copy
import numpy as np

from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
-    MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
+    MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger

@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-    check_split
+    check_split, check_cluedataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:

@@ -4317,6 +4317,222 @@ class CelebADataset(MappableDataset):
        return self.sampler.is_sharded()


class CLUEDataset(SourceDataset):
    """
    A source dataset that reads and parses CLUE datasets.

    CLUE, the Chinese Language Understanding Evaluation benchmark, is a collection of datasets,
    baselines, pre-trained models, corpora and a leaderboard. This class supports CLUE's
    classification tasks: AFQMC, TNEWS, IFLYTEK, CMNLI, WSC and CSL.

    Args:
        dataset_files (str or list[str]): String or list of files to be read, or glob strings to
            search for a pattern of files. The list will be sorted in lexicographical order.
        task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC'
            and 'CSL' (default='AFQMC').
        usage (str, optional): Which split of the task to read, one of 'train', 'test' or 'eval'
            (default='train').
        num_samples (int, optional): Number of samples (rows) to read (default=None, reads the
            full dataset).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, the number set in the config).
        shuffle (Union[bool, Shuffle], optional): Perform reshuffling of the data every epoch
            (default=Shuffle.GLOBAL). If shuffle is False, no shuffling will be performed.
            If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
            Otherwise, there are two levels of shuffling:

            - Shuffle.GLOBAL: Shuffle both the files and the samples.

            - Shuffle.FILES: Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset should be divided into
            (default=None).
        shard_id (int, optional): The shard ID within num_shards (default=None). This argument
            should be specified only when num_shards is also specified.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset_files = ["/path/to/1", "/path/to/2"]  # contains one or more text files
        >>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train')
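        >>> # The produced columns depend on task and usage; for task='AFQMC' with
        >>> # usage='train', each row carries 'sentence1', 'sentence2' and 'label'.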
    """

    @check_cluedataset
    def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
        super().__init__(num_parallel_workers)
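        # _find_files (inherited from SourceDataset) expands glob patterns into concrete
        # file paths, per the dataset_files description above; sorting keeps the file
        # order deterministic.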
        self.dataset_files = self._find_files(dataset_files)
        self.dataset_files.sort()
        self.num_samples = num_samples
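
        # For each supported (task, usage) pair, map every output column name to the
        # key that holds it in the raw CLUE json lines; a path such as
        # 'target/span1_index' points into a nested json object.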
        self.task_dict = {
            'AFQMC': {
                'train': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                },
                'test': {
                    'id': 'id',
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2'
                },
                'eval': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                }
            },
            'CMNLI': {
                'train': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                },
                'test': {
                    'id': 'id',
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2'
                },
                'eval': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                }
            },
            'CSL': {
                'train': {
                    'id': 'id',
                    'abst': 'abst',
                    'keyword': 'keyword',
                    'label': 'label'
                },
                'test': {
                    'id': 'id',
                    'abst': 'abst',
                    'keyword': 'keyword'
                },
                'eval': {
                    'id': 'id',
                    'abst': 'abst',
                    'keyword': 'keyword',
                    'label': 'label'
                }
            },
            'IFLYTEK': {
                'train': {
                    'label': 'label',
                    'label_des': 'label_des',
                    'sentence': 'sentence'
                },
                'test': {
                    'id': 'id',
                    'sentence': 'sentence'
                },
                'eval': {
                    'label': 'label',
                    'label_des': 'label_des',
                    'sentence': 'sentence'
                }
            },
            'TNEWS': {
                'train': {
                    'label': 'label',
                    'label_desc': 'label_desc',
                    'sentence': 'sentence',
                    'keywords': 'keywords'
                },
                'test': {
                    'id': 'id',
                    'sentence': 'sentence',
                    'keywords': 'keywords'
                },
                'eval': {
                    'label': 'label',
                    'label_desc': 'label_desc',
                    'sentence': 'sentence',
                    'keywords': 'keywords'
                }
            },
            'WSC': {
                'train': {
                    'span1_index': 'target/span1_index',
                    'span2_index': 'target/span2_index',
                    'span1_text': 'target/span1_text',
                    'span2_text': 'target/span2_text',
                    'idx': 'idx',
                    'label': 'label',
                    'text': 'text'
                },
                'test': {
                    'span1_index': 'target/span1_index',
                    'span2_index': 'target/span2_index',
                    'span1_text': 'target/span1_text',
                    'span2_text': 'target/span2_text',
                    'idx': 'idx',
                    'text': 'text'
                },
                'eval': {
                    'span1_index': 'target/span1_index',
                    'span2_index': 'target/span2_index',
                    'span1_text': 'target/span1_text',
                    'span2_text': 'target/span2_text',
                    'idx': 'idx',
                    'label': 'label',
                    'text': 'text'
                }
            }
        }
        self.cols_to_keyword = self.task_dict[task][usage]
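
        # Normalize `shuffle` into a shuffle level plus a file-shuffling flag:
        # True maps to Shuffle.GLOBAL, False disables shuffling entirely, and an
        # explicit Shuffle enum value is kept as-is (with file shuffling enabled).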
        if not isinstance(shuffle, (bool, Shuffle)):
            raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
        if not isinstance(shuffle, Shuffle):
            if shuffle:
                self.shuffle_level = Shuffle.GLOBAL
                self.shuffle_files = True
            else:
                self.shuffle_level = None
                self.shuffle_files = False
        else:
            self.shuffle_level = shuffle
            self.shuffle_files = True

        self.num_shards = num_shards
        self.shard_id = shard_id

    def get_args(self):
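        # Gather this node's construction parameters into a dict, extending the
        # args collected from the SourceDataset base class.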
        args = super().get_args()
        args["dataset_files"] = self.dataset_files
        args["num_samples"] = self.num_samples
        if self.shuffle_files is not None:
            args["shuffle_files"] = self.shuffle_files
        args["shuffle"] = self.shuffle_level
        args["num_shards"] = self.num_shards
        args["shard_id"] = self.shard_id
        args["cols_to_keyword"] = self.cols_to_keyword
        return args

    def get_dataset_size(self):
        """
        Get the number of rows in the dataset.

        Return:
            Number, number of rows.
        """
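        # The row count is computed on demand: ClueOp counts the rows in the
        # dataset files, the get_num_rows helper adjusts that total for the shard
        # count, and num_samples (if given) caps the result.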
        if self._dataset_size is None:
            num_rows = ClueOp.get_num_rows(self.dataset_files)
            num_rows = get_num_rows(num_rows, self.num_shards)
            if self.num_samples is None:
                return num_rows
            return min(self.num_samples, num_rows)
        return self._dataset_size

    def is_shuffled(self):
        return self.shuffle_files

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return False


class TextFileDataset(SourceDataset):
    """
    A source dataset that reads and parses datasets stored on disk in text format.