|
|
@ -43,9 +43,8 @@ class TestMultipleReader(unittest.TestCase):
|
|
|
|
filename='./mnist.recordio',
|
|
|
|
filename='./mnist.recordio',
|
|
|
|
shapes=[(-1, 784), (-1, 1)],
|
|
|
|
shapes=[(-1, 784), (-1, 1)],
|
|
|
|
lod_levels=[0, 0],
|
|
|
|
lod_levels=[0, 0],
|
|
|
|
dtypes=['float32', 'int64'])
|
|
|
|
dtypes=['float32', 'int64'],
|
|
|
|
data_file = fluid.layers.io.multi_pass(
|
|
|
|
pass_num=self.pass_num)
|
|
|
|
reader=data_file, pass_num=self.pass_num)
|
|
|
|
|
|
|
|
img, label = fluid.layers.read_file(data_file)
|
|
|
|
img, label = fluid.layers.read_file(data_file)
|
|
|
|
|
|
|
|
|
|
|
|
if fluid.core.is_compiled_with_cuda():
|
|
|
|
if fluid.core.is_compiled_with_cuda():
|
|
|
@ -65,5 +64,4 @@ class TestMultipleReader(unittest.TestCase):
|
|
|
|
break
|
|
|
|
break
|
|
|
|
batch_count += 1
|
|
|
|
batch_count += 1
|
|
|
|
self.assertLessEqual(img_val.shape[0], self.batch_size)
|
|
|
|
self.assertLessEqual(img_val.shape[0], self.batch_size)
|
|
|
|
data_file.reset()
|
|
|
|
|
|
|
|
self.assertEqual(batch_count, self.num_batch * self.pass_num)
|
|
|
|
self.assertEqual(batch_count, self.num_batch * self.pass_num)
|
|
|
|