Add more unittests and fix bugs

helinwang-patch-1
fengjiayi 7 years ago
parent f863866471
commit 2532b922dc

@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}

@ -1 +1,4 @@
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio

@ -22,9 +22,10 @@ from shutil import copyfile
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase):
'./mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
print(self.num_batch)
def test_multiple_reader(self, thread_num=3):
def main(self, thread_num):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase):
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
print(batch_count)
# data_files.reset()
print("FUCK")
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number

Loading…
Cancel
Save