Refine DataLoader support multi-processing (#23107)
* add DataLoader, Dataset, BatchSamplerrevert-22778-infer_var_type
parent
76d78c6387
commit
80cf3c3c4d
@ -0,0 +1,24 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from . import dataset
|
||||
from .dataset import *
|
||||
|
||||
from . import batch_sampler
|
||||
from .batch_sampler import *
|
||||
|
||||
__all__ = dataset.__all__ \
|
||||
+ batch_sampler.__all__
|
@ -0,0 +1,143 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
from .dataset import Dataset
|
||||
|
||||
__all__ = ["BatchSampler"]
|
||||
|
||||
|
||||
class BatchSampler(object):
|
||||
"""
|
||||
A base implement of batch sampler used by `paddle.io.DataLoader`
|
||||
which yield mini-batch indices(a list/tuple with length as
|
||||
mini-batch size and holds sample indices) iterably.
|
||||
|
||||
Batch sampler used by :code:`paddle.io.DataLoader` should be a subclass
|
||||
of :code:`paddle.io.BatchSampler`, BatchSampler subclasses should
|
||||
implement following methods:
|
||||
|
||||
:code:`__iter__`: return mini-batch indices iterably.
|
||||
|
||||
:code:`__len__`: get mini-batch number in an epoch.
|
||||
|
||||
|
||||
Args:
|
||||
dataset(Dataset): this could be a :code:`paddle.io.Dataset`
|
||||
implement or other python object which implemented
|
||||
:code:`__len__` for BatchSampler to get indices as the
|
||||
range of :attr:`dataset` length. Default None.
|
||||
indices (list|tuple): a substitution parameter for
|
||||
:attr:`dataset` either :attr:`dataset` or
|
||||
:attr:`indices` should be set, give the whole
|
||||
indices to sampler from directly. Default None.
|
||||
shuffle(bool): whether to shuffle indices order before genrating
|
||||
batch indices. Default False.
|
||||
batch_size(int): sample indice number in a mini-batch indices.
|
||||
drop_last(bool): whether drop the last incomplete batch dataset size
|
||||
is not divisible by the batch size. Default False
|
||||
|
||||
Returns:
|
||||
BatchSampler: an iterable object for indices iterating
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from paddle.io import BatchSampler, Dataset
|
||||
|
||||
# init with indices
|
||||
bs = BatchSampler(indices=list(range(100)),
|
||||
shuffle=True,
|
||||
batch_size=8,
|
||||
drop_last=True)
|
||||
|
||||
for batch_indices in bs:
|
||||
print(batch_indices)
|
||||
|
||||
# init with dataset
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples):
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image = np.random.random([784]).astype('float32')
|
||||
label = np.random.randint(0, 9, (1, )).astype('int64')
|
||||
return image, label
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
bs = BatchSampler(dataset=RandomDataset(100),
|
||||
shuffle=False,
|
||||
batch_size=16,
|
||||
drop_last=False)
|
||||
|
||||
for batch_indices in bs:
|
||||
print(batch_indices)
|
||||
|
||||
see `paddle.io.DataLoader`
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset=None,
|
||||
indices=None,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
drop_last=False):
|
||||
if dataset is None:
|
||||
assert indices is not None, \
|
||||
"either dataset or indices should be set"
|
||||
assert isinstance(indices, list) or isinstance(indices, tuple), \
|
||||
"indices should be a list or tuple, but got {}".format(type(indices))
|
||||
self.indices = indices
|
||||
else:
|
||||
assert isinstance(dataset, Dataset), \
|
||||
"dataset should be an instance of paddle.io.Dataset"
|
||||
assert indices is None, \
|
||||
"should not set both dataset and indices"
|
||||
self.indices = list(range(len(dataset)))
|
||||
|
||||
assert isinstance(batch_size, int) and batch_size > 0, \
|
||||
"batch_size should be a positive integer, but got {}".format(batch_size)
|
||||
self.batch_size = batch_size
|
||||
assert isinstance(shuffle, bool), \
|
||||
"shuffle should be a boolean value, but got {}".format(type(shuffle))
|
||||
self.shuffle = shuffle
|
||||
assert isinstance(drop_last, bool), \
|
||||
"drop_last should be a boolean value, but got {}".format(type(drop_last))
|
||||
self.drop_last = drop_last
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
np.random.shuffle(self.indices)
|
||||
_iter = iter(self.indices)
|
||||
|
||||
batch_indices = []
|
||||
for idx in _iter:
|
||||
batch_indices.append(idx)
|
||||
if len(batch_indices) == self.batch_size:
|
||||
yield batch_indices
|
||||
batch_indices = []
|
||||
if not self.drop_last and len(batch_indices) > 0:
|
||||
yield batch_indices
|
||||
|
||||
def __len__(self):
|
||||
num_samples = len(self.indices)
|
||||
num_samples += int(not self.drop_last) * (self.batch_size - 1)
|
||||
return num_samples // self.batch_size
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.dataset.common
|
||||
|
||||
__all__ = ["Dataset"]
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
"""
|
||||
An abstract class to encapsulates methods and behaviors of datasets.
|
||||
|
||||
All datasets in map-style(dataset samples can be get by a given key)
|
||||
should be a subclass of `paddle.io.Dataset`. All subclasses should
|
||||
implement following methods:
|
||||
|
||||
:code:`__getitem__`: get sample from dataset with a given index. This
|
||||
method is required by reading dataset sample in :code:`paddle.io.DataLoader`.
|
||||
|
||||
:code:`__len__`: return dataset sample number. This method is required
|
||||
by some implements of :code:`paddle.io.BatchSampler`
|
||||
|
||||
see :code:`paddle.io.DataLoader`.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
from paddle.io import Dataset
|
||||
|
||||
# define a random dataset
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples):
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image = np.random.random([784]).astype('float32')
|
||||
label = np.random.randint(0, 9, (1, )).astype('int64')
|
||||
return image, label
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
dataset = RandomDataset(10)
|
||||
for i in range(len(dataset)):
|
||||
print(dataset[i])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __getitem__(self, idx):
|
||||
raise NotImplementedError("'{}' not implement in class "\
|
||||
"{}".format('__getitem__', self.__class__.__name__))
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError("'{}' not implement in class "\
|
||||
"{}".format('__len__', self.__class__.__name__))
|
@ -0,0 +1,139 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import six
|
||||
import sys
|
||||
import signal
|
||||
import atexit
|
||||
|
||||
from . import core
|
||||
|
||||
# NOTE: queue has a different name in python2 and python3
|
||||
if six.PY2:
|
||||
import Queue as queue
|
||||
else:
|
||||
import queue
|
||||
|
||||
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
|
||||
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
|
||||
# from the child process will automatically clear the memory-mapped file.
|
||||
multiprocess_queue_set = set()
|
||||
|
||||
|
||||
def _clear_multiprocess_queue_set():
|
||||
global multiprocess_queue_set
|
||||
for data_queue in multiprocess_queue_set:
|
||||
while True:
|
||||
try:
|
||||
data_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
|
||||
# NOTE: main process clear function at exit
|
||||
def _cleanup():
|
||||
# NOTE: inter-process Queue shared memory objects clear function
|
||||
_clear_multiprocess_queue_set()
|
||||
# NOTE: main process memory map files clear funciton
|
||||
core._cleanup_mmap_fds()
|
||||
|
||||
|
||||
# NOTE: for child process clear function at exit
|
||||
def _cleanup_mmap():
|
||||
# clear memory map files in child process
|
||||
core._cleanup_mmap_fds()
|
||||
|
||||
|
||||
# NOTE used for register a function to be executed at interpreter exit.
|
||||
class CleanupFuncRegistrar():
|
||||
# Record the cleanup functions that have been executed
|
||||
_executed_func_set = set()
|
||||
# Record the cleanup functions that have been registered
|
||||
_registered_func_set = set()
|
||||
|
||||
@classmethod
|
||||
def register(cls, function, signals=[]):
|
||||
def _func_exectuor():
|
||||
if function not in cls._executed_func_set:
|
||||
try:
|
||||
function()
|
||||
finally:
|
||||
cls._executed_func_set.add(function)
|
||||
|
||||
def _func_register(function):
|
||||
if not callable(function):
|
||||
raise TypeError("%s is not callable object." % (function))
|
||||
# check function object whether hash-able
|
||||
set([function])
|
||||
if function not in cls._registered_func_set:
|
||||
atexit.register(_func_exectuor)
|
||||
cls._registered_func_set.add(function)
|
||||
|
||||
def _signal_handler(signum=None, frame=None):
|
||||
_func_exectuor()
|
||||
if signum is not None:
|
||||
if signum == signal.SIGINT:
|
||||
raise KeyboardInterrupt
|
||||
sys.exit(signum)
|
||||
|
||||
def _signal_register(signals):
|
||||
signals = set(signals)
|
||||
for sig in signals:
|
||||
orig_handler = signal.signal(sig, _signal_handler)
|
||||
if orig_handler not in (signal.SIG_DFL, signal.SIG_IGN):
|
||||
if (sig == signal.SIGINT and
|
||||
orig_handler is signal.default_int_handler):
|
||||
continue
|
||||
if orig_handler not in cls._registered_func_set:
|
||||
atexit.register(orig_handler)
|
||||
cls._registered_func_set.add(orig_handler)
|
||||
|
||||
# deal with signals
|
||||
_signal_register(signals)
|
||||
# deal with function
|
||||
_func_register(function)
|
||||
|
||||
|
||||
# NOTE: [ mmap files clear ] When the main process exits unexpectedly, the remaining
|
||||
# shared memory objects in the inter-process Queue and the main process (mostly in the
|
||||
# BlockingQueue) may not be completely released, resulting in the corresponding
|
||||
# memory-mapped file remaining on the disk (/dev/shm), so register this function
|
||||
# to clean up shared memory objects in these two queues before the python interpreter exits.
|
||||
# NOTE: Currently multi-process DataLoader only supports Linux platform
|
||||
if not (sys.platform == 'darwin' or sys.platform == 'win32'):
|
||||
CleanupFuncRegistrar.register(_cleanup)
|
||||
|
||||
# ------------ SIGCHLD handler setting --------------
|
||||
_SIGCHLD_handler_set = False
|
||||
|
||||
|
||||
def _set_SIGCHLD_handler():
|
||||
global _SIGCHLD_handler_set
|
||||
if _SIGCHLD_handler_set:
|
||||
return
|
||||
|
||||
current_handler = signal.getsignal(signal.SIGCHLD)
|
||||
if not callable(current_handler):
|
||||
current_handler = None
|
||||
|
||||
def __handler__(signum, frame):
|
||||
# NOTE: Here the signum is SIGCHLD, when the child process exits,
|
||||
# this handler will be called whenever the child process exits
|
||||
# normally or abnormally.
|
||||
core._throw_error_if_process_failed()
|
||||
if current_handler is not None:
|
||||
current_handler(signum, frame)
|
||||
|
||||
signal.signal(signal.SIGCHLD, __handler__)
|
||||
_SIGCHLD_handler_set = True
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,120 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import unittest
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.io import BatchSampler, Dataset
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, sample_num, class_num):
|
||||
self.sample_num = sample_num
|
||||
self.class_num = class_num
|
||||
|
||||
def __getitem__(self, idx):
|
||||
np.random.seed(idx)
|
||||
image = np.random.random([IMAGE_SIZE]).astype('float32')
|
||||
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
|
||||
return image, label
|
||||
|
||||
def __len__(self):
|
||||
return self.sample_num
|
||||
|
||||
|
||||
class TestBatchSampler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.num_samples = 1000
|
||||
self.num_classes = 10
|
||||
self.batch_size = 32
|
||||
self.shuffle = False
|
||||
self.drop_last = False
|
||||
|
||||
def init_batch_sampler(self):
|
||||
dataset = RandomDataset(self.num_samples, self.num_classes)
|
||||
bs = BatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=self.shuffle,
|
||||
drop_last=self.drop_last)
|
||||
return bs
|
||||
|
||||
def test_main(self):
|
||||
bs = self.init_batch_sampler()
|
||||
# length check
|
||||
bs_len = (self.num_samples + int(not self.drop_last) \
|
||||
* (self.batch_size - 1)) // self.batch_size
|
||||
self.assertTrue(bs_len == len(bs))
|
||||
|
||||
# output indices check
|
||||
if not self.shuffle:
|
||||
index = 0
|
||||
for indices in bs:
|
||||
for idx in indices:
|
||||
self.assertTrue(index == idx)
|
||||
index += 1
|
||||
|
||||
|
||||
class TestBatchSamplerDropLast(TestBatchSampler):
|
||||
def setUp(self):
|
||||
self.num_samples = 1000
|
||||
self.num_classes = 10
|
||||
self.batch_size = 32
|
||||
self.shuffle = False
|
||||
self.drop_last = True
|
||||
|
||||
|
||||
class TestBatchSamplerShuffle(TestBatchSampler):
|
||||
def setUp(self):
|
||||
self.num_samples = 1000
|
||||
self.num_classes = 10
|
||||
self.batch_size = 32
|
||||
self.shuffle = True
|
||||
self.drop_last = True
|
||||
|
||||
|
||||
class TestBatchSamplerWithIndices(TestBatchSampler):
|
||||
def init_batch_sampler(self):
|
||||
bs = BatchSampler(
|
||||
indices=list(range(self.num_samples)),
|
||||
batch_size=self.batch_size,
|
||||
drop_last=self.drop_last)
|
||||
return bs
|
||||
|
||||
|
||||
class TestBatchSamplerWithIndicesAndDataSource(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.num_samples = 1000
|
||||
self.num_classes = 10
|
||||
self.batch_size = 32
|
||||
self.shuffle = False
|
||||
self.drop_last = True
|
||||
|
||||
def test_main(self):
|
||||
try:
|
||||
dataset = RandomDataset(self.num_samples, self.num_classes)
|
||||
bs = BatchSampler(
|
||||
dataset=dataset,
|
||||
indices=list(range(self.num_samples)),
|
||||
batch_size=self.batch_size,
|
||||
drop_last=self.drop_last)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.io import *
|
||||
|
||||
|
||||
class TestDatasetAbstract(unittest.TestCase):
|
||||
def test_main(self):
|
||||
dataset = Dataset()
|
||||
try:
|
||||
d = dataset[0]
|
||||
self.assertTrue(False)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
try:
|
||||
l = len(dataset)
|
||||
self.assertTrue(False)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import sys
|
||||
import six
|
||||
import time
|
||||
import unittest
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.io import Dataset, BatchSampler, DataLoader
|
||||
from paddle.fluid.dygraph.nn import Linear
|
||||
from paddle.fluid.dygraph.base import to_variable
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, sample_num):
|
||||
self.sample_num = sample_num
|
||||
|
||||
def __getitem__(self, idx):
|
||||
np.random.seed(idx)
|
||||
image = np.random.random([784]).astype('float32')
|
||||
label = np.random.randint(0, 9, (1, )).astype('int64')
|
||||
return image, label
|
||||
|
||||
def __len__(self):
|
||||
return self.sample_num
|
||||
|
||||
|
||||
class TestDataLoaderAssert(unittest.TestCase):
|
||||
def test_main(self):
|
||||
place = fluid.cpu_places()[0]
|
||||
with fluid.dygraph.guard(place):
|
||||
dataset = RandomDataset(100)
|
||||
batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
|
||||
|
||||
# dataset is not instance of Dataset
|
||||
try:
|
||||
loader = DataLoader(dataset=batch_sampler, places=place)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# places is None
|
||||
try:
|
||||
loader = DataLoader(dataset=dataset, places=None)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# num_workers < 0
|
||||
try:
|
||||
loader = DataLoader(
|
||||
dataset=dataset, places=place, num_workers=-1)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# timeout < 0
|
||||
try:
|
||||
loader = DataLoader(dataset=dataset, places=place, timeout=-1)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# batch_sampler is not instance of BatchSampler
|
||||
try:
|
||||
loader = DataLoader(
|
||||
dataset=dataset, places=place, batch_sampler=dataset)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# set batch_sampler and shuffle/batch_size/drop_last
|
||||
try:
|
||||
loader = DataLoader(
|
||||
dataset=dataset,
|
||||
places=place,
|
||||
batch_sampler=batch_sampler,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
# set batch_sampler correctly
|
||||
try:
|
||||
loader = DataLoader(
|
||||
dataset=dataset, places=place, batch_sampler=batch_sampler)
|
||||
self.assertTrue(True)
|
||||
except AssertionError:
|
||||
self.assertTrue(False)
|
||||
|
||||
|
||||
# CI Converage cannot record stub in subprocess,
|
||||
# HACK a _worker_loop in main process call here
|
||||
class TestDataLoaderWorkerLoop(unittest.TestCase):
|
||||
def run_without_worker_done(self, use_shared_memory=True):
|
||||
try:
|
||||
place = fluid.cpu_places()[0]
|
||||
with fluid.dygraph.guard(place):
|
||||
dataset = RandomDataset(800)
|
||||
|
||||
# test init_fn
|
||||
def _init_fn(worker_id):
|
||||
pass
|
||||
|
||||
# test collate_fn
|
||||
def _collate_fn(sample_list):
|
||||
return [
|
||||
np.stack(
|
||||
s, axis=0) for s in list(zip(*sample_list))
|
||||
]
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
num_workers=1,
|
||||
places=place,
|
||||
use_shared_memory=use_shared_memory)
|
||||
assert loader.num_workers > 0, \
|
||||
"go to AssertionError and pass in Mac and Windows"
|
||||
loader = iter(loader)
|
||||
print("loader length", len(loader))
|
||||
indices_queue = multiprocessing.Queue()
|
||||
for i in range(10):
|
||||
indices_queue.put([i, i + 10])
|
||||
indices_queue.put(None)
|
||||
loader._worker_loop(
|
||||
loader._dataset, indices_queue, loader._data_queue,
|
||||
loader._workers_done_event, _collate_fn, _init_fn, 0)
|
||||
self.assertTrue(False)
|
||||
except AssertionError:
|
||||
pass
|
||||
except Exception:
|
||||
self.assertTrue(False)
|
||||
|
||||
def run_with_worker_done(self, use_shared_memory=True):
|
||||
try:
|
||||
place = fluid.cpu_places()[0]
|
||||
with fluid.dygraph.guard(place):
|
||||
dataset = RandomDataset(800)
|
||||
|
||||
# test init_fn
|
||||
def _init_fn(worker_id):
|
||||
pass
|
||||
|
||||
# test collate_fn
|
||||
def _collate_fn(sample_list):
|
||||
return [
|
||||
np.stack(
|
||||
s, axis=0) for s in list(zip(*sample_list))
|
||||
]
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
num_workers=1,
|
||||
places=place,
|
||||
use_shared_memory=use_shared_memory)
|
||||
assert loader.num_workers > 0, \
|
||||
"go to AssertionError and pass in Mac and Windows"
|
||||
loader = iter(loader)
|
||||
print("loader length", len(loader))
|
||||
indices_queue = multiprocessing.Queue()
|
||||
for i in range(10):
|
||||
indices_queue.put([i, i + 10])
|
||||
indices_queue.put(None)
|
||||
loader._workers_done_event.set()
|
||||
loader._worker_loop(
|
||||
loader._dataset, indices_queue, loader._data_queue,
|
||||
loader._workers_done_event, _collate_fn, _init_fn, 0)
|
||||
self.assertTrue(True)
|
||||
except AssertionError:
|
||||
pass
|
||||
except Exception:
|
||||
self.assertTrue(False)
|
||||
|
||||
def test_main(self):
|
||||
for use_shared_memory in [True, False]:
|
||||
self.run_without_worker_done(use_shared_memory)
|
||||
self.run_with_worker_done(use_shared_memory)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue