drop the last batch, if the size of last batch is not equal to batch_size

wangkuiyi-patch-1
chengduoZH 7 years ago
parent 376c948e88
commit 164692da9a

@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size):
def batch(reader, batch_size, drop_last=False):
"""
Create a batched reader.
@ -23,6 +23,8 @@ def batch(reader, batch_size):
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
@ -35,7 +37,7 @@ def batch(reader, batch_size):
if len(b) == batch_size:
yield b
b = []
if b:
if drop_last == False and len(b) != 0:
yield b
return batch_reader

@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size):
def batch(reader, batch_size, drop_last=False):
"""
Create a batched reader.
@ -23,6 +23,8 @@ def batch(reader, batch_size):
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
@ -35,7 +37,7 @@ def batch(reader, batch_size):
if len(b) == batch_size:
yield b
b = []
if b:
if drop_last == False and len(b) != 0:
yield b
return batch_reader

Loading…
Cancel
Save