Just train two batch

guochaorong-patch-1
yuyang18 7 years ago
parent bcb1516d11
commit e8eb81ca28
No known key found for this signature in database
GPG Key ID: 6DFF29878217BE5F

@ -520,7 +520,7 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None, name=None):
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op( startup_blk.append_op(
type='create_py_reader', type='create_py_reader',
inputs={'blocking_queue': queue_name}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]}, outputs={'Out': [startup_var]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,

@ -15,6 +15,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.dataset.mnist as mnist import paddle.dataset.mnist as mnist
import paddle import paddle
import paddle.v2
import threading import threading
import numpy import numpy
@ -91,7 +92,8 @@ def main():
for epoch_id in xrange(10): for epoch_id in xrange(10):
train_data_thread = pipe_reader_to_queue( train_data_thread = pipe_reader_to_queue(
paddle.batch(mnist.train(), 32), train_queue) paddle.batch(paddle.v2.reader.firstn(mnist.train(), 32), 64),
train_queue)
try: try:
while True: while True:
print 'train_loss', numpy.array( print 'train_loss', numpy.array(

Loading…
Cancel
Save