Merge pull request #11402 from chengduoZH/refine_batch_func

Change drop_last defalut value
wangkuiyi-patch-1
chengduo 7 years ago committed by GitHub
commit b4dfd08004
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size, drop_last=False):
def batch(reader, batch_size, drop_last=True):
"""
Create a batched reader.

@ -96,10 +96,11 @@ def train(use_cuda, train_program, params_dirname):
train_reader = paddle.batch(
paddle.reader.shuffle(
cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10),
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
drop_last=False)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):

@ -73,10 +73,11 @@ def train(use_cuda, train_program, params_dirname):
train_reader = paddle.batch(
paddle.reader.shuffle(
cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10),
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
drop_last=False)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):

@ -87,7 +87,9 @@ def train(use_cuda, train_program, params_dirname):
def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
test_reader = paddle.batch(
paddle.dataset.imdb.test(word_dict), batch_size=BATCH_SIZE)
paddle.dataset.imdb.test(word_dict),
batch_size=BATCH_SIZE,
drop_last=False)
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['words', 'label'])
@ -113,7 +115,8 @@ def train(use_cuda, train_program, params_dirname):
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=25000),
batch_size=BATCH_SIZE)
batch_size=BATCH_SIZE,
drop_last=False)
trainer.train(
num_epochs=1,

@ -56,7 +56,7 @@ BATCH_SIZE = 200
# fix the order of training data
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE)
paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE, drop_last=False)
# train_reader = paddle.batch(
# paddle.reader.shuffle(

@ -15,7 +15,7 @@
__all__ = ['batch']
def batch(reader, batch_size, drop_last=False):
def batch(reader, batch_size, drop_last=True):
"""
Create a batched reader.

Loading…
Cancel
Save