|
|
|
@ -44,7 +44,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|
|
|
|
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
|
|
|
|
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
|
|
|
|
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
|
|
|
|
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset
|
|
|
|
|
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\
|
|
|
|
|
check_paddeddataset
|
|
|
|
|
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
|
|
|
|
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
|
|
|
|
|
|
|
|
|
@ -2305,10 +2306,35 @@ class ConcatDataset(DatasetOp):
|
|
|
|
|
if not isinstance(dataset, Dataset):
|
|
|
|
|
raise TypeError("The parameter %s of concat has type error!" % (dataset))
|
|
|
|
|
self.datasets = datasets
|
|
|
|
|
self._sampler = None
|
|
|
|
|
for data in datasets:
|
|
|
|
|
self.children.append(data)
|
|
|
|
|
data.parent.append(self)
|
|
|
|
|
|
|
|
|
|
self.children_sizes_ = [c.get_dataset_size() for c in self.children]
|
|
|
|
|
"""
|
|
|
|
|
_children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
|
|
|
|
|
whether the data set is mappable. The second element of pair is length of the dataset
|
|
|
|
|
"""
|
|
|
|
|
self._children_flag_and_nums = []
|
|
|
|
|
"""
|
|
|
|
|
_children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
|
|
|
|
|
the valid position of the dataset corresponding to the subscript when sampling
|
|
|
|
|
"""
|
|
|
|
|
self._children_start_end_index_ = []
|
|
|
|
|
for index, child in enumerate(self.children):
|
|
|
|
|
tem_list = [-1, -1]
|
|
|
|
|
self._children_start_end_index_.append(tem_list)
|
|
|
|
|
datasetLen = self.children_sizes_[index]
|
|
|
|
|
if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"):
|
|
|
|
|
datasetLen = 0
|
|
|
|
|
self.children_sizes_[index] = 0
|
|
|
|
|
|
|
|
|
|
if isinstance(child, MappableDataset):
|
|
|
|
|
self._children_flag_and_nums.append((0, datasetLen))
|
|
|
|
|
else:
|
|
|
|
|
self._children_flag_and_nums.append((1, datasetLen))
|
|
|
|
|
|
|
|
|
|
def get_dataset_size(self):
|
|
|
|
|
"""
|
|
|
|
|
Get the number of batches in an epoch.
|
|
|
|
@ -2321,6 +2347,67 @@ class ConcatDataset(DatasetOp):
|
|
|
|
|
self.dataset_size = sum(children_sizes)
|
|
|
|
|
return self.dataset_size
|
|
|
|
|
|
|
|
|
|
def use_sampler(self, sampler):
|
|
|
|
|
"""
|
|
|
|
|
Set the distributedSampler to concat dataset
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
sampler (Sampler): the sampler to use for the current dataset. Current support: DistributedSampler.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If the sampler is not an istance of DistributedSampler
|
|
|
|
|
ValueError: If the parameter shuffle of sampler is True
|
|
|
|
|
ValueError: If the parameter NumSamples of sampler is not None.
|
|
|
|
|
ValueError: If num_shards <=0.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(sampler, samplers.DistributedSampler):
|
|
|
|
|
raise TypeError("The parameter %s of concat should be DistributedSampler!" % (sampler))
|
|
|
|
|
|
|
|
|
|
if sampler.is_shuffled():
|
|
|
|
|
raise ValueError("The parameter shuffle of DistributedSampler is not support to be true!")
|
|
|
|
|
|
|
|
|
|
if sampler.num_shards <= 0:
|
|
|
|
|
raise ValueError("The parameter num_shards of concat should be positive int!")
|
|
|
|
|
|
|
|
|
|
if sampler.get_num_samples() is not None:
|
|
|
|
|
raise ValueError("The parameter NumSamples of DistributedSampler is not support to be set!")
|
|
|
|
|
|
|
|
|
|
self._sampler = _select_sampler(None, sampler, None, None, None)
|
|
|
|
|
cumulative_samples_nums = 0
|
|
|
|
|
for index, child in enumerate(self.children):
|
|
|
|
|
|
|
|
|
|
if isinstance(child, BatchDataset):
|
|
|
|
|
raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child))
|
|
|
|
|
|
|
|
|
|
if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]:
|
|
|
|
|
|
|
|
|
|
tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1]
|
|
|
|
|
|
|
|
|
|
if not self._children_flag_and_nums[index][1] >= sampler.num_shards:
|
|
|
|
|
if tem_value < sampler.num_shards:
|
|
|
|
|
self._children_start_end_index_[index][0] = cumulative_samples_nums
|
|
|
|
|
self._children_start_end_index_[index][1] = tem_value
|
|
|
|
|
else:
|
|
|
|
|
self._children_start_end_index_[index][0] = cumulative_samples_nums
|
|
|
|
|
self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tem_sampler = copy.deepcopy(sampler)
|
|
|
|
|
tem_sampler.set_offset(cumulative_samples_nums)
|
|
|
|
|
child.sampler = tem_sampler
|
|
|
|
|
|
|
|
|
|
cumulative_samples_nums += self.children_sizes_[index]
|
|
|
|
|
cumulative_samples_nums %= sampler.num_shards
|
|
|
|
|
|
|
|
|
|
def get_args(self):
|
|
|
|
|
args = super().get_args()
|
|
|
|
|
|
|
|
|
|
if self._sampler is not None:
|
|
|
|
|
args["sampler"] = self._sampler
|
|
|
|
|
args["children_flag_and_nums"] = self._children_flag_and_nums
|
|
|
|
|
args["children_start_end_index"] = self._children_start_end_index_
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RenameDataset(DatasetOp):
|
|
|
|
|
"""
|
|
|
|
@ -3307,7 +3394,6 @@ class GeneratorDataset(MappableDataset):
|
|
|
|
|
new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
|
|
|
|
new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
|
|
|
|
|
new_op.dataset_size = self.dataset_size
|
|
|
|
|
|
|
|
|
|
new_op.sampler = copy.deepcopy(self.sampler)
|
|
|
|
|
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
|
|
|
|
|
if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
|
|
|
@ -5276,6 +5362,53 @@ class NumpySlicesDataset(GeneratorDataset):
|
|
|
|
|
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
|
|
|
|
|
num_shards=num_shards, shard_id=shard_id)
|
|
|
|
|
|
|
|
|
|
class _PaddedDataset:
|
|
|
|
|
"""
|
|
|
|
|
Mainly for combining false samples provided by users into a dataset.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
padded_samples (list(dict)): the data provided by user to added to initial Dataset
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, padded_samples):
|
|
|
|
|
self.column_names = list(padded_samples[0].keys())
|
|
|
|
|
self.padded_samples = padded_samples
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, item):
|
|
|
|
|
return (self.padded_samples[item][key] for key in self.column_names)
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return len(self.padded_samples)
|
|
|
|
|
|
|
|
|
|
class PaddedDataset(GeneratorDataset):
|
|
|
|
|
"""
|
|
|
|
|
Create a dataset with fake data provided by user. Mainly used to add to the original data set
|
|
|
|
|
and assign it to the corresponding shard.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
padded_samples (list(dict)): the samples provided by user
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If padded_samples is not an instance of list.
|
|
|
|
|
TypeError: If the element of padded_samples is not an instance of dict.
|
|
|
|
|
ValueError: If the padded_samples is empty.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> import mindspore.dataset as ds
|
|
|
|
|
>>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
|
|
|
|
|
>>> ds1 = ds.PaddedDataset(data1)
|
|
|
|
|
"""
|
|
|
|
|
@check_paddeddataset
|
|
|
|
|
def __init__(self, padded_samples):
|
|
|
|
|
dataset = _PaddedDataset(padded_samples)
|
|
|
|
|
super().__init__(dataset, column_names=dataset.column_names,
|
|
|
|
|
num_shards=None,
|
|
|
|
|
shard_id=None, shuffle=False)
|
|
|
|
|
self._dataset_size = len(dataset.padded_samples)
|
|
|
|
|
self.padded_samples = padded_samples
|
|
|
|
|
|
|
|
|
|
def get_dataset_size(self):
|
|
|
|
|
return self._dataset_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BuildVocabDataset(DatasetOp):
|
|
|
|
|
"""
|
|
|
|
|