|
|
|
@ -22,9 +22,10 @@ from shutil import copyfile
|
|
|
|
|
|
|
|
|
|
class TestMultipleReader(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.batch_size = 64
|
|
|
|
|
# Convert mnist to recordio file
|
|
|
|
|
with fluid.program_guard(fluid.Program(), fluid.Program()):
|
|
|
|
|
reader = paddle.batch(mnist.train(), batch_size=32)
|
|
|
|
|
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
|
|
|
|
|
feeder = fluid.DataFeeder(
|
|
|
|
|
feed_list=[ # order is image and label
|
|
|
|
|
fluid.layers.data(
|
|
|
|
@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase):
|
|
|
|
|
'./mnist_0.recordio', reader, feeder)
|
|
|
|
|
copyfile('./mnist_0.recordio', './mnist_1.recordio')
|
|
|
|
|
copyfile('./mnist_0.recordio', './mnist_2.recordio')
|
|
|
|
|
print(self.num_batch)
|
|
|
|
|
|
|
|
|
|
def test_multiple_reader(self, thread_num=3):
|
|
|
|
|
def main(self, thread_num):
|
|
|
|
|
file_list = [
|
|
|
|
|
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
|
|
|
|
|
]
|
|
|
|
@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase):
|
|
|
|
|
while not data_files.eof():
|
|
|
|
|
img_val, = exe.run(fetch_list=[img])
|
|
|
|
|
batch_count += 1
|
|
|
|
|
print(batch_count)
|
|
|
|
|
# data_files.reset()
|
|
|
|
|
print("FUCK")
|
|
|
|
|
|
|
|
|
|
self.assertLessEqual(img_val.shape[0], self.batch_size)
|
|
|
|
|
data_files.reset()
|
|
|
|
|
self.assertEqual(batch_count, self.num_batch * 3)
|
|
|
|
|
|
|
|
|
|
def test_main(self):
|
|
|
|
|
self.main(thread_num=3) # thread number equals to file number
|
|
|
|
|
self.main(thread_num=10) # thread number is larger than file number
|
|
|
|
|
self.main(thread_num=2) # thread number is less than file number
|
|
|
|
|