parent
6ab6c35621
commit
f846362927
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from decorator import *
|
@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['buffered']
|
||||
|
||||
from Queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
|
||||
def buffered(reader, size):
|
||||
"""Creates a buffered data reader.
|
||||
|
||||
The buffered data reader will read and save data entries into a buffer.
|
||||
Reading from the buffered data reader will proceed as long as the buffer
|
||||
is not empty.
|
||||
|
||||
Args:
|
||||
reader: the data reader to read from.
|
||||
size: max buffer size.
|
||||
|
||||
Returns:
|
||||
The buffered data reader.
|
||||
"""
|
||||
|
||||
class EndSignal():
|
||||
pass
|
||||
|
||||
end = EndSignal()
|
||||
|
||||
def read_worker(r, q):
|
||||
for d in r:
|
||||
q.put(d)
|
||||
q.put(end)
|
||||
|
||||
def create_reader():
|
||||
r = reader()
|
||||
q = Queue(maxsize=size)
|
||||
t = Thread(
|
||||
target=read_worker, args=(
|
||||
r,
|
||||
q, ))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
e = q.get()
|
||||
while e != end:
|
||||
yield e
|
||||
e = q.get()
|
||||
|
||||
return create_reader
|
@ -0,0 +1,4 @@
|
||||
add_test(NAME reader_decorator_test
|
||||
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
|
||||
${PYTHON_EXECUTABLE} ${PROJ_ROOT}/python/paddle/reader/decorator_test.py
|
||||
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
|
@ -0,0 +1,50 @@
|
||||
# Copyright PaddlePaddle contributors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
import paddle.reader
|
||||
import time
|
||||
|
||||
|
||||
def reader_10(dur):
|
||||
for i in range(10):
|
||||
time.sleep(dur)
|
||||
yield i
|
||||
|
||||
|
||||
class TestBuffered(unittest.TestCase):
|
||||
def test_read(self):
|
||||
for size in range(20):
|
||||
b = paddle.reader.buffered(lambda: reader_10(0), size)
|
||||
c = 0
|
||||
for i in b():
|
||||
self.assertEqual(i, c)
|
||||
c += 1
|
||||
self.assertEqual(c, 10)
|
||||
|
||||
def test_buffering(self):
|
||||
# read have 30ms delay.
|
||||
b = paddle.reader.buffered(lambda: reader_10(0.03), 10)
|
||||
last_time = time.time()
|
||||
for idx, i in enumerate(b()):
|
||||
elapsed_time = time.time() - last_time
|
||||
if i == 0:
|
||||
time.sleep(0.3)
|
||||
else:
|
||||
# read time should be short, meaning already buffered.
|
||||
self.assertLess(elapsed_time, 0.01)
|
||||
last_time = time.time()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue