|
|
|
|
@ -89,9 +89,11 @@ def train(use_cuda, train_program, parallel, params_dirname):
|
|
|
|
|
cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10),
|
|
|
|
|
batch_size=BATCH_SIZE,
|
|
|
|
|
drop_last=False)
|
|
|
|
|
|
|
|
|
|
# Use only part of the test set data validation program
|
|
|
|
|
test_reader = paddle.batch(
|
|
|
|
|
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False)
|
|
|
|
|
cifar10_small_test_set.test10(BATCH_SIZE),
|
|
|
|
|
batch_size=BATCH_SIZE,
|
|
|
|
|
drop_last=False)
|
|
|
|
|
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, EndStepEvent):
|
|
|
|
|
|