Refine DataLoader to support multi-processing (#23107)

* add DataLoader, Dataset, BatchSampler

parent 76d78c6387
commit 80cf3c3c4d
@@ -0,0 +1,24 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from . import dataset
from .dataset import *

from . import batch_sampler
from .batch_sampler import *

__all__ = dataset.__all__ \
    + batch_sampler.__all__
@@ -0,0 +1,143 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import division

import numpy as np
from .dataset import Dataset

__all__ = ["BatchSampler"]


class BatchSampler(object):
    """
    A base implementation of a batch sampler used by `paddle.io.DataLoader`,
    which iteratively yields mini-batch indices (a list/tuple whose length is
    the mini-batch size and which holds sample indices).

    A batch sampler used by :code:`paddle.io.DataLoader` should be a subclass
    of :code:`paddle.io.BatchSampler`. BatchSampler subclasses should
    implement the following methods:

    :code:`__iter__`: iteratively return mini-batch indices.

    :code:`__len__`: return the number of mini-batches in an epoch.

    Args:
        dataset(Dataset): this can be an instance of :code:`paddle.io.Dataset`
            or any other Python object that implements :code:`__len__`, which
            BatchSampler uses to generate indices in the range of the
            :attr:`dataset` length. Default None.
        indices(list|tuple): a substitute parameter for :attr:`dataset`;
            either :attr:`dataset` or :attr:`indices` should be set. It gives
            the whole set of indices to sample from directly. Default None.
        shuffle(bool): whether to shuffle the indices order before generating
            batch indices. Default False.
        batch_size(int): the number of sample indices in a mini-batch. Default 1.
        drop_last(bool): whether to drop the last incomplete batch when the
            dataset size is not divisible by the batch size. Default False.

    Returns:
        BatchSampler: an iterable object for iterating over batch indices.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import BatchSampler, Dataset

            # init with indices
            bs = BatchSampler(indices=list(range(100)),
                              shuffle=True,
                              batch_size=8,
                              drop_last=True)

            for batch_indices in bs:
                print(batch_indices)

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            bs = BatchSampler(dataset=RandomDataset(100),
                              shuffle=False,
                              batch_size=16,
                              drop_last=False)

            for batch_indices in bs:
                print(batch_indices)

    see `paddle.io.DataLoader`

    """

    def __init__(self,
                 dataset=None,
                 indices=None,
                 shuffle=False,
                 batch_size=1,
                 drop_last=False):
        if dataset is None:
            assert indices is not None, \
                "either dataset or indices should be set"
            assert isinstance(indices, list) or isinstance(indices, tuple), \
                "indices should be a list or tuple, but got {}".format(type(indices))
            self.indices = indices
        else:
            assert isinstance(dataset, Dataset), \
                "dataset should be an instance of paddle.io.Dataset"
            assert indices is None, \
                "should not set both dataset and indices"
            self.indices = list(range(len(dataset)))

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer, but got {}".format(batch_size)
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value, but got {}".format(type(shuffle))
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value, but got {}".format(type(drop_last))
        self.drop_last = drop_last

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
        _iter = iter(self.indices)

        batch_indices = []
        for idx in _iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        num_samples = len(self.indices)
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size
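For context on the `__len__` arithmetic above: adding `batch_size - 1` before the floor division turns it into a ceiling division when `drop_last` is False, so a trailing partial batch is still counted. A minimal standalone sketch with illustrative values (not part of the patch):

# Sketch of the batch-count arithmetic used in BatchSampler.__len__,
# assuming 10 samples and a batch size of 4 (illustrative values only).
num_samples = 10
batch_size = 4

# drop_last=True: plain floor division, the trailing partial batch is discarded.
assert num_samples // batch_size == 2          # batches [0..3] and [4..7]

# drop_last=False: adding (batch_size - 1) makes this a ceiling division,
# so the trailing partial batch [8, 9] is counted as well.
assert (num_samples + batch_size - 1) // batch_size == 3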
File diff suppressed because it is too large
@@ -0,0 +1,73 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle.dataset.common

__all__ = ["Dataset"]


class Dataset(object):
    """
    An abstract class to encapsulate methods and behaviors of datasets.

    All map-style datasets (datasets whose samples can be fetched by a given
    key) should be subclasses of `paddle.io.Dataset`. All subclasses should
    implement the following methods:

    :code:`__getitem__`: get a sample from the dataset with a given index.
    This method is required for reading dataset samples in
    :code:`paddle.io.DataLoader`.

    :code:`__len__`: return the number of samples in the dataset. This method
    is required by some implementations of :code:`paddle.io.BatchSampler`.

    see :code:`paddle.io.DataLoader`.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset

            # define a random dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(10)
            for i in range(len(dataset)):
                print(dataset[i])

    """

    def __init__(self):
        pass

    def __getitem__(self, idx):
        raise NotImplementedError("'{}' not implemented in class "\
                "{}".format('__getitem__', self.__class__.__name__))

    def __len__(self):
        raise NotImplementedError("'{}' not implemented in class "\
                "{}".format('__len__', self.__class__.__name__))
@@ -0,0 +1,139 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import sys
import signal
import atexit

from . import core

# NOTE: queue has a different name in python2 and python3
if six.PY2:
    import Queue as queue
else:
    import queue

# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
#       the data in the queue needs to be popped. Then the LoDTensor read by the main process
#       from the child process will automatically clear the memory-mapped file.
multiprocess_queue_set = set()


def _clear_multiprocess_queue_set():
    global multiprocess_queue_set
    for data_queue in multiprocess_queue_set:
        while True:
            try:
                data_queue.get_nowait()
            except queue.Empty:
                break


# NOTE: main process clear function at exit
def _cleanup():
    # NOTE: inter-process Queue shared memory objects clear function
    _clear_multiprocess_queue_set()
    # NOTE: main process memory map files clear function
    core._cleanup_mmap_fds()


# NOTE: child process clear function at exit
def _cleanup_mmap():
    # clear memory map files in child process
    core._cleanup_mmap_fds()


# NOTE: used to register a function to be executed at interpreter exit.
class CleanupFuncRegistrar():
    # Record the cleanup functions that have been executed
    _executed_func_set = set()
    # Record the cleanup functions that have been registered
    _registered_func_set = set()

    @classmethod
    def register(cls, function, signals=[]):
        def _func_executor():
            if function not in cls._executed_func_set:
                try:
                    function()
                finally:
                    cls._executed_func_set.add(function)

        def _func_register(function):
            if not callable(function):
                raise TypeError("%s is not callable object." % (function))
            # check whether the function object is hashable
            set([function])
            if function not in cls._registered_func_set:
                atexit.register(_func_executor)
                cls._registered_func_set.add(function)

        def _signal_handler(signum=None, frame=None):
            _func_executor()
            if signum is not None:
                if signum == signal.SIGINT:
                    raise KeyboardInterrupt
                sys.exit(signum)

        def _signal_register(signals):
            signals = set(signals)
            for sig in signals:
                orig_handler = signal.signal(sig, _signal_handler)
                if orig_handler not in (signal.SIG_DFL, signal.SIG_IGN):
                    if (sig == signal.SIGINT and
                            orig_handler is signal.default_int_handler):
                        continue
                    if orig_handler not in cls._registered_func_set:
                        atexit.register(orig_handler)
                        cls._registered_func_set.add(orig_handler)

        # deal with signals
        _signal_register(signals)
        # deal with function
        _func_register(function)


# NOTE: [ mmap files clear ] When the main process exits unexpectedly, the remaining
#       shared memory objects in the inter-process Queue and the main process (mostly in the
#       BlockingQueue) may not be completely released, resulting in the corresponding
#       memory-mapped file remaining on the disk (/dev/shm), so register this function
#       to clean up shared memory objects in these two queues before the python interpreter exits.
# NOTE: Currently multi-process DataLoader only supports the Linux platform
if not (sys.platform == 'darwin' or sys.platform == 'win32'):
    CleanupFuncRegistrar.register(_cleanup)

# ------------ SIGCHLD handler setting --------------
_SIGCHLD_handler_set = False


def _set_SIGCHLD_handler():
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return

    current_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(current_handler):
        current_handler = None

    def __handler__(signum, frame):
        # NOTE: Here the signum is SIGCHLD; this handler will be called
        # whenever a child process exits, normally or abnormally.
        core._throw_error_if_process_failed()
        if current_handler is not None:
            current_handler(signum, frame)

    signal.signal(signal.SIGCHLD, __handler__)
    _SIGCHLD_handler_set = True
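As a usage illustration (not part of the patch), `CleanupFuncRegistrar.register` takes any hashable callable plus an optional list of signals; the callable then runs at most once, either at interpreter exit or when one of those signals fires. A minimal, self-contained sketch of that pattern using only the standard library; `_remove_temp_files` is a hypothetical cleanup function, and the real registrar additionally chains to previously installed handlers:

# Run-once cleanup registered for both normal exit and SIGTERM (illustrative only).
import atexit
import signal
import sys


def _remove_temp_files():
    # hypothetical cleanup work, e.g. deleting files left under /dev/shm
    print("cleaning up temporary shared-memory files")


_executed = set()


def _run_once():
    if _remove_temp_files not in _executed:
        try:
            _remove_temp_files()
        finally:
            _executed.add(_remove_temp_files)


def _handler(signum, frame):
    _run_once()
    sys.exit(signum)


atexit.register(_run_once)               # runs on normal interpreter exit
signal.signal(signal.SIGTERM, _handler)  # runs once if the process is terminated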
File diff suppressed because it is too large
@@ -0,0 +1,120 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import unittest
import numpy as np

import paddle.fluid as fluid
from paddle.io import BatchSampler, Dataset


class RandomDataset(Dataset):
    def __init__(self, sample_num, class_num):
        self.sample_num = sample_num
        self.class_num = class_num

    def __getitem__(self, idx):
        np.random.seed(idx)
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.sample_num


class TestBatchSampler(unittest.TestCase):
    def setUp(self):
        self.num_samples = 1000
        self.num_classes = 10
        self.batch_size = 32
        self.shuffle = False
        self.drop_last = False

    def init_batch_sampler(self):
        dataset = RandomDataset(self.num_samples, self.num_classes)
        bs = BatchSampler(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=self.drop_last)
        return bs

    def test_main(self):
        bs = self.init_batch_sampler()
        # length check
        bs_len = (self.num_samples + int(not self.drop_last) \
                * (self.batch_size - 1)) // self.batch_size
        self.assertTrue(bs_len == len(bs))

        # output indices check
        if not self.shuffle:
            index = 0
            for indices in bs:
                for idx in indices:
                    self.assertTrue(index == idx)
                    index += 1


class TestBatchSamplerDropLast(TestBatchSampler):
    def setUp(self):
        self.num_samples = 1000
        self.num_classes = 10
        self.batch_size = 32
        self.shuffle = False
        self.drop_last = True


class TestBatchSamplerShuffle(TestBatchSampler):
    def setUp(self):
        self.num_samples = 1000
        self.num_classes = 10
        self.batch_size = 32
        self.shuffle = True
        self.drop_last = True


class TestBatchSamplerWithIndices(TestBatchSampler):
    def init_batch_sampler(self):
        bs = BatchSampler(
            indices=list(range(self.num_samples)),
            batch_size=self.batch_size,
            drop_last=self.drop_last)
        return bs


class TestBatchSamplerWithIndicesAndDataSource(unittest.TestCase):
    def setUp(self):
        self.num_samples = 1000
        self.num_classes = 10
        self.batch_size = 32
        self.shuffle = False
        self.drop_last = True

    def test_main(self):
        try:
            dataset = RandomDataset(self.num_samples, self.num_classes)
            bs = BatchSampler(
                dataset=dataset,
                indices=list(range(self.num_samples)),
                batch_size=self.batch_size,
                drop_last=self.drop_last)
            self.assertTrue(False)
        except AssertionError:
            pass


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,41 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import unittest
import numpy as np

import paddle.fluid as fluid
from paddle.io import *


class TestDatasetAbstract(unittest.TestCase):
    def test_main(self):
        dataset = Dataset()
        try:
            d = dataset[0]
            self.assertTrue(False)
        except NotImplementedError:
            pass

        try:
            l = len(dataset)
            self.assertTrue(False)
        except NotImplementedError:
            pass


if __name__ == '__main__':
    unittest.main()
File diff suppressed because it is too large
@@ -0,0 +1,199 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import os
import sys
import six
import time
import unittest
import multiprocessing
import numpy as np

import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable


class RandomDataset(Dataset):
    def __init__(self, sample_num):
        self.sample_num = sample_num

    def __getitem__(self, idx):
        np.random.seed(idx)
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.sample_num


class TestDataLoaderAssert(unittest.TestCase):
    def test_main(self):
        place = fluid.cpu_places()[0]
        with fluid.dygraph.guard(place):
            dataset = RandomDataset(100)
            batch_sampler = BatchSampler(dataset=dataset, batch_size=4)

            # dataset is not an instance of Dataset
            try:
                loader = DataLoader(dataset=batch_sampler, places=place)
                self.assertTrue(False)
            except AssertionError:
                pass

            # places is None
            try:
                loader = DataLoader(dataset=dataset, places=None)
                self.assertTrue(False)
            except AssertionError:
                pass

            # num_workers < 0
            try:
                loader = DataLoader(
                    dataset=dataset, places=place, num_workers=-1)
                self.assertTrue(False)
            except AssertionError:
                pass

            # timeout < 0
            try:
                loader = DataLoader(dataset=dataset, places=place, timeout=-1)
                self.assertTrue(False)
            except AssertionError:
                pass

            # batch_sampler is not an instance of BatchSampler
            try:
                loader = DataLoader(
                    dataset=dataset, places=place, batch_sampler=dataset)
                self.assertTrue(False)
            except AssertionError:
                pass

            # set batch_sampler and shuffle/batch_size/drop_last
            try:
                loader = DataLoader(
                    dataset=dataset,
                    places=place,
                    batch_sampler=batch_sampler,
                    shuffle=True,
                    drop_last=True)
                self.assertTrue(False)
            except AssertionError:
                pass

            # set batch_sampler correctly
            try:
                loader = DataLoader(
                    dataset=dataset, places=place, batch_sampler=batch_sampler)
                self.assertTrue(True)
            except AssertionError:
                self.assertTrue(False)


# CI Coverage cannot record stubs in subprocesses,
# HACK: call _worker_loop in the main process here
class TestDataLoaderWorkerLoop(unittest.TestCase):
    def run_without_worker_done(self, use_shared_memory=True):
        try:
            place = fluid.cpu_places()[0]
            with fluid.dygraph.guard(place):
                dataset = RandomDataset(800)

                # test init_fn
                def _init_fn(worker_id):
                    pass

                # test collate_fn
                def _collate_fn(sample_list):
                    return [
                        np.stack(
                            s, axis=0) for s in list(zip(*sample_list))
                    ]

                loader = DataLoader(
                    dataset,
                    num_workers=1,
                    places=place,
                    use_shared_memory=use_shared_memory)
                assert loader.num_workers > 0, \
                    "go to AssertionError and pass in Mac and Windows"
                loader = iter(loader)
                print("loader length", len(loader))
                indices_queue = multiprocessing.Queue()
                for i in range(10):
                    indices_queue.put([i, i + 10])
                indices_queue.put(None)
                loader._worker_loop(
                    loader._dataset, indices_queue, loader._data_queue,
                    loader._workers_done_event, _collate_fn, _init_fn, 0)
                self.assertTrue(False)
        except AssertionError:
            pass
        except Exception:
            self.assertTrue(False)

    def run_with_worker_done(self, use_shared_memory=True):
        try:
            place = fluid.cpu_places()[0]
            with fluid.dygraph.guard(place):
                dataset = RandomDataset(800)

                # test init_fn
                def _init_fn(worker_id):
                    pass

                # test collate_fn
                def _collate_fn(sample_list):
                    return [
                        np.stack(
                            s, axis=0) for s in list(zip(*sample_list))
                    ]

                loader = DataLoader(
                    dataset,
                    num_workers=1,
                    places=place,
                    use_shared_memory=use_shared_memory)
                assert loader.num_workers > 0, \
                    "go to AssertionError and pass in Mac and Windows"
                loader = iter(loader)
                print("loader length", len(loader))
                indices_queue = multiprocessing.Queue()
                for i in range(10):
                    indices_queue.put([i, i + 10])
                indices_queue.put(None)
                loader._workers_done_event.set()
                loader._worker_loop(
                    loader._dataset, indices_queue, loader._data_queue,
                    loader._workers_done_event, _collate_fn, _init_fn, 0)
                self.assertTrue(True)
        except AssertionError:
            pass
        except Exception:
            self.assertTrue(False)

    def test_main(self):
        for use_shared_memory in [True, False]:
            self.run_without_worker_done(use_shared_memory)
            self.run_with_worker_done(use_shared_memory)


if __name__ == '__main__':
    unittest.main()
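To tie the pieces of this patch together, here is a hedged end-to-end sketch of how Dataset, BatchSampler, and the multi-process DataLoader are combined. It uses only constructor arguments that appear in the tests above (dataset, places, batch_sampler, num_workers, use_shared_memory); the per-item unpacking and the training step are assumptions, not part of this diff.

# Minimal usage sketch combining Dataset, BatchSampler and the multi-process
# DataLoader (assumptions noted inline; illustrative only).
import numpy as np
import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader


class RandomDataset(Dataset):
    def __init__(self, sample_num):
        self.sample_num = sample_num

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.sample_num


place = fluid.cpu_places()[0]
with fluid.dygraph.guard(place):
    dataset = RandomDataset(100)
    # an explicit BatchSampler; shuffle/batch_size/drop_last must not also be
    # passed to DataLoader when batch_sampler is given (see the asserts tested above)
    sampler = BatchSampler(dataset=dataset, batch_size=16, shuffle=True)
    loader = DataLoader(
        dataset,
        places=place,
        batch_sampler=sampler,
        num_workers=2,            # >0 enables multi-process loading (Linux only)
        use_shared_memory=True)   # workers hand batches back via shared memory

    for data in loader:  # assumption: each item is one collated mini-batch
        pass             # a training step would consume `data` here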