|
|
|
@ -15,7 +15,7 @@
|
|
|
|
|
__all__ = ['batch']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch(reader, batch_size):
|
|
|
|
|
def batch(reader, batch_size, drop_last=False):
|
|
|
|
|
"""
|
|
|
|
|
Create a batched reader.
|
|
|
|
|
|
|
|
|
@ -23,6 +23,8 @@ def batch(reader, batch_size):
|
|
|
|
|
:type reader: callable
|
|
|
|
|
:param batch_size: size of each mini-batch
|
|
|
|
|
:type batch_size: int
|
|
|
|
|
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
|
|
|
|
|
:type drop_last: bool
|
|
|
|
|
:return: the batched reader.
|
|
|
|
|
:rtype: callable
|
|
|
|
|
"""
|
|
|
|
@ -35,7 +37,7 @@ def batch(reader, batch_size):
|
|
|
|
|
if len(b) == batch_size:
|
|
|
|
|
yield b
|
|
|
|
|
b = []
|
|
|
|
|
if b:
|
|
|
|
|
if drop_last == False and len(b) != 0:
|
|
|
|
|
yield b
|
|
|
|
|
|
|
|
|
|
return batch_reader
|
|
|
|
|