You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
326 lines
8.9 KiB
326 lines
8.9 KiB
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Public names exported by ``from <module> import *``.
__all__ = [
    'map_readers',
    'buffered',
    'compose',
    'chain',
    'shuffle',
    'ComposeNotAligned',
    'firstn',
    'xmap_readers',
]
|
|
|
|
import itertools
import random
from threading import Condition
from threading import Thread

try:
    from Queue import Queue  # Python 2
except ImportError:
    from queue import Queue  # Python 3
|
|
|
|
|
|
def map_readers(func, *readers):
    """
    Creates a data reader that outputs the return value of a function,
    using the outputs of the given data readers as its arguments.

    The readers are advanced in lockstep and the created reader stops as
    soon as the shortest input reader is exhausted.

    :param func: function to use. It is called with one sample from each
                 reader (in the order the readers were given) and its
                 return value becomes the output sample.
    :type func: callable
    :param readers: readers whose outputs will be used as arguments of func.
    :return: the created data reader.
    :rtype: callable
    """
    # itertools.imap only exists on Python 2. On Python 3 the builtin map is
    # already lazy and also stops at the shortest input.
    lazy_map = getattr(itertools, 'imap', map)

    def reader():
        # Instantiate every reader up front, then map over them lazily.
        rs = [r() for r in readers]
        for e in lazy_map(func, *rs):
            yield e

    return reader
|
|
|
|
|
|
def shuffle(reader, buf_size):
    """
    Creates a data reader that emits the samples of another reader in
    shuffled order.

    Samples from the original reader are collected into a buffer. Each time
    the buffer holds at least buf_size samples it is shuffled and emitted,
    then emptied. Whatever is left in the buffer when the original reader
    ends is shuffled and emitted as well, so shuffling only happens within
    one buffer's worth of samples.

    :param reader: the original reader whose output will be shuffled.
    :type reader: callable
    :param buf_size: shuffle buffer size.
    :type buf_size: int

    :return: the new reader whose output is shuffled.
    :rtype: callable
    """

    def data_reader():
        pending = []
        for sample in reader():
            pending.append(sample)
            if len(pending) < buf_size:
                continue
            # Buffer is full: shuffle it, flush it and start a new one.
            random.shuffle(pending)
            for item in pending:
                yield item
            pending = []

        # Flush the final, partially filled buffer.
        if pending:
            random.shuffle(pending)
            for item in pending:
                yield item

    return data_reader
|
|
|
|
|
|
def chain(*readers):
    """
    Creates a data reader that outputs the samples of the input readers
    one reader after another.

    If three input readers output, respectively:
    0, 0, 0
    1, 1, 1
    2, 2, 2
    the chained reader will output:
    0, 0, 0, 1, 1, 1, 2, 2, 2

    All input readers are invoked before the first sample is produced.

    :param readers: input readers.
    :return: the new data reader.
    :rtype: callable
    """

    def reader():
        iterators = [make() for make in readers]
        for sample in itertools.chain(*iterators):
            yield sample

    return reader
|
|
|
|
|
|
class ComposeNotAligned(ValueError):
    """Raised by compose() when its input readers are not aligned."""
|
|
|
|
|
|
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    Outputs that are tuples are spliced into the result; any other output
    (including None) is treated as a single element.

    :param readers: readers that will be composed together.
    :param check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True.
    :type check_alignment: bool

    :return: the new data reader.

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    # Python 2 spells the lazy variants izip / izip_longest; Python 3 has the
    # lazy builtin zip and itertools.zip_longest.
    lazy_zip = getattr(itertools, 'izip', zip)
    zip_longest = getattr(itertools, 'izip_longest', None)
    if zip_longest is None:
        zip_longest = itertools.zip_longest

    # Unique fill value for exhausted readers. Unlike None it can never be
    # confused with a sample that a reader legitimately yields.
    missing = object()

    def make_tuple(x):
        if isinstance(x, tuple):
            return x
        return (x, )

    def flatten(outputs):
        return tuple(
            itertools.chain.from_iterable(make_tuple(o) for o in outputs))

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            # Stops at the shortest reader; longer readers are truncated.
            for outputs in lazy_zip(*rs):
                yield flatten(outputs)
        else:
            for outputs in zip_longest(*rs, fillvalue=missing):
                # The fill value only shows up when some reader ended
                # earlier than the others.
                if any(o is missing for o in outputs):
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield flatten(outputs)

    return reader
|
|
|
|
|
|
def buffered(reader, size):
    """
    Creates a buffered data reader.

    Every invocation of the buffered reader starts a daemon thread that
    reads entries from ``reader()`` into a queue holding at most ``size``
    entries, while the caller consumes entries from that queue. An
    exception raised by the original reader is re-raised to the caller
    after the entries read before the failure have been delivered.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    """

    class EndSignal(object):
        pass

    end = EndSignal()

    def read_worker(r, q, errors):
        try:
            for d in r:
                q.put(d)
        except Exception as exc:
            # Keep the failure so the consuming side can re-raise it.
            errors.append(exc)
        finally:
            # Always send the end marker; without it the consumer would
            # block on q.get() forever when the reader fails.
            q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        errors = []
        t = Thread(target=read_worker, args=(r, q, errors))
        t.daemon = True
        t.start()
        e = q.get()
        # Compare by identity: samples may define their own comparison
        # operators (e.g. element-wise ones), which must not be invoked here.
        while e is not end:
            yield e
            e = q.get()
        if errors:
            raise errors[0]

    return data_reader
|
|
|
|
|
|
def firstn(reader, n):
    """
    Limit the max number of samples that reader could return.

    The decorated reader stops right after the n-th sample and does not
    pull any further sample from the original reader. When n is 0 the
    original reader is not invoked at all. A negative n never matches the
    sample count, so no limit is applied.

    :param reader: the data reader to read from.
    :type reader: callable
    :param n: the max number of samples that return.
    :type n: int
    :return: the decorated reader.
    :rtype: callable
    """

    # TODO(yuyang18): Check if just drop the reader, could clean the opened
    # resource or not?

    def firstn_reader():
        if n == 0:
            return
        for count, item in enumerate(reader(), 1):
            yield item
            if count == n:
                # Stop here instead of fetching sample n + 1 only to
                # throw it away.
                break

    return firstn_reader
|
|
|
|
|
|
class XmapEndSignal():
    """Marker type that xmap_readers puts on its queues to signal the end."""
    pass


def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
    """
    Use several worker threads to map samples from reader by a mapper
    defined by user. Samples are passed through an input queue and an
    output queue, each created with buffer_size as its max size.

    An exception raised by the reader or by the mapper is re-raised to the
    caller of the decorated reader.

    :param mapper: a function to map sample.
    :type mapper: callable
    :param reader: the data reader to read from
    :type reader: callable
    :param process_num: number of worker threads that apply the mapper
    :type process_num: int
    :param buffer_size: max buffer size
    :type buffer_size: int
    :param order: keep the order of reader
    :type order: bool
    :return: the decorated reader
    :rtype: callable
    :raises ValueError: if process_num is less than 1.
    """
    if process_num < 1:
        # Without any worker nothing would ever reach the output queue and
        # the decorated reader would block forever.
        raise ValueError(
            "process_num must be at least 1, got %r" % (process_num, ))

    end = XmapEndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue):
        for in_order, i in enumerate(reader()):
            in_queue.put((in_order, i))
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            out_queue.put(mapper(sample))
            sample = in_queue.get()
        # Put the end marker back for the remaining workers, then report
        # this worker as finished to the consumer.
        in_queue.put(end)
        out_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue by order
    def order_handle_worker(in_queue, out_queue, mapper, out_order, turn):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order, sample = ins
            r = mapper(sample)
            with turn:
                # Sleep until it is this sample's turn instead of spinning
                # on out_order, which would burn CPU time.
                while order != out_order[0]:
                    turn.wait()
                out_queue.put(r)
                out_order[0] += 1
                turn.notify_all()
            ins = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        out_order = [0]
        turn = Condition()
        errors = []

        def guarded(target):
            # Wrap a thread target so that a failure is recorded and the
            # consumer is woken up instead of waiting forever.
            def run(*args):
                try:
                    target(*args)
                except Exception as exc:
                    errors.append(exc)
                    out_queue.put(end)

            return run

        # start a read worker in a thread
        target = order_read_worker if order else read_worker
        t = Thread(target=guarded(target), args=(reader, in_queue))
        t.daemon = True
        t.start()

        # start several handle_workers
        target = order_handle_worker if order else handle_worker
        args = (in_queue, out_queue, mapper, out_order, turn) if order else (
            in_queue, out_queue, mapper)
        workers = []
        for _ in range(process_num):
            worker = Thread(target=guarded(target), args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        # Every worker sends one end marker when it is done; keep yielding
        # samples until all of them have reported.
        finished = 0
        while finished < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                if errors:
                    raise errors[0]
                finished += 1
            else:
                yield sample

    return xreader
|