|
|
|
@ -17,7 +17,10 @@ from __future__ import print_function
|
|
|
|
|
from .. import framework
|
|
|
|
|
import paddle.dataset.common
|
|
|
|
|
|
|
|
|
|
__all__ = ["Dataset", "IterableDataset", "TensorDataset"]
|
|
|
|
|
__all__ = [
|
|
|
|
|
"Dataset", "IterableDataset", "TensorDataset", "ComposeDataset",
|
|
|
|
|
"ChainDataset"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Dataset(object):
|
|
|
|
@ -275,3 +278,130 @@ class TensorDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return self.tensors[0].shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_list(value):
    """Normalize ``value`` into a list.

    Args:
        value: Any object, or None.

    Returns:
        None if ``value`` is None; a new list holding the items of
        ``value`` if it is a list or tuple; otherwise a single-element
        list containing ``value``.
    """
    is_sequence = isinstance(value, (list, tuple))
    if is_sequence:
        # Always hand back a fresh list, even when a list was passed in.
        return list(value)
    # None passes through untouched; any other object gets wrapped.
    return value if value is None else [value]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ComposeDataset(Dataset):
    """
    A Dataset which composes fields of multiple datasets.

    This dataset is used for composing fields of multiple map-style
    datasets of same length. The sample at a given index is a tuple
    made of the fields of every input dataset at that index, in the
    order the datasets were given.

    Args:
        datasets(list of Dataset): List of datasets to be composed.

    Returns:
        Dataset: A Dataset which composes fields of multiple datasets.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.io import Dataset, ComposeDataset


            # define a random dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([32]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = ComposeDataset([RandomDataset(10), RandomDataset(10)])
            for i in range(len(dataset)):
                image1, label1, image2, label2 = dataset[i]
                print(image1)
                print(label1)
                print(image2)
                print(label2)

    """

    def __init__(self, datasets):
        # Materialize the input so it can be indexed and re-iterated.
        self.datasets = list(datasets)
        assert self.datasets, "input datasets shoule not be empty"

        previous = None
        for current in self.datasets:
            assert isinstance(current, Dataset), \
                "each input dataset should be paddle.io.Dataset"
            assert not isinstance(current, IterableDataset), \
                "paddle.io.IterableDataset not supported"
            # Every dataset after the first must be as long as the one
            # right before it, which makes all lengths equal.
            if previous is not None:
                assert len(current) == len(previous), \
                    "lengths of datasets should be same"
            previous = current

    def __len__(self):
        # All datasets share one length, so the first one is representative.
        first = self.datasets[0]
        return len(first)

    def __getitem__(self, idx):
        # Flatten the fields of each dataset's sample into a single tuple.
        return tuple(field
                     for member in self.datasets
                     for field in to_list(member[idx]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChainDataset(IterableDataset):
    """
    A Dataset which chains multiple iterable-style datasets.

    This dataset is used for assembling multiple datasets which should
    be :code:`paddle.io.IterableDataset`. Iterating it yields every
    sample of the first dataset, then every sample of the second one,
    and so on.

    Args:
        datasets(list of Dataset): List of datasets to be chained.

    Returns:
        Dataset: A Dataset which chains the samples of multiple datasets.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.io import IterableDataset, ChainDataset


            # define a random dataset
            class RandomDataset(IterableDataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __iter__(self):
                    for i in range(self.num_samples):
                        image = np.random.random([32]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        yield image, label

            dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
            for image, label in iter(dataset):
                print(image, label)

    """

    def __init__(self, datasets):
        # Materialize the input so it can be iterated more than once.
        self.datasets = list(datasets)
        assert self.datasets, "input datasets shoule not be empty"
        assert all(isinstance(member, IterableDataset)
                   for member in self.datasets), \
            "ChainDataset only support paddle.io.IterableDataset"

    def __iter__(self):
        # Exhaust each dataset in turn, in the order they were given.
        for source in self.datasets:
            stream = iter(source)
            for item in stream:
                yield item
|
|
|
|
|