|
|
|
@ -38,7 +38,27 @@ from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCH
|
|
|
|
|
MP_INDICES_CHECK_INTERVAL = 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _default_collate_fn(batch):
|
|
|
|
|
def default_collate_fn(batch):
|
|
|
|
|
"""
|
|
|
|
|
Default batch collating function for :code:`fluid.io.DataLoader`,
|
|
|
|
|
batch should be a list of samples, and each sample should be a list
|
|
|
|
|
of fields as follows:
|
|
|
|
|
|
|
|
|
|
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
|
|
|
|
|
|
|
|
|
|
This default collate function zipped each filed together and stack
|
|
|
|
|
each filed as the batch field as follows:
|
|
|
|
|
|
|
|
|
|
[batch_filed1, batch_filed2, ...]
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
batch(list of list of numpy array): the batch data, each fields
|
|
|
|
|
should be a numpy array, each sample should be a list of
|
|
|
|
|
fileds, and batch should be a list of sample.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
a list of numpy array: collated batch
|
|
|
|
|
"""
|
|
|
|
|
sample = batch[0]
|
|
|
|
|
# dataset has only 1 field
|
|
|
|
|
if isinstance(sample, np.ndarray):
|
|
|
|
@ -82,7 +102,7 @@ class _DataLoaderIterBase(object):
|
|
|
|
|
self._return_list = loader.return_list
|
|
|
|
|
self._batch_sampler = loader.batch_sampler
|
|
|
|
|
self._sampler_iter = iter(loader.batch_sampler)
|
|
|
|
|
self._collate_fn = loader.collate_fn or _default_collate_fn
|
|
|
|
|
self._collate_fn = loader.collate_fn or default_collate_fn
|
|
|
|
|
self._num_workers = loader.num_workers
|
|
|
|
|
self._use_buffer_reader = loader.use_buffer_reader
|
|
|
|
|
self._use_shared_memory = loader.use_shared_memory
|
|
|
|
|