|
|
|
@ -34,16 +34,6 @@ def sample_generator_creator(batch_size, batch_num):
|
|
|
|
|
return __reader__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_generator_creator(batch_size, batch_num):
|
|
|
|
|
def __reader__():
|
|
|
|
|
for _ in range(batch_num):
|
|
|
|
|
batch_image, batch_label = get_random_images_and_labels(
|
|
|
|
|
[batch_size, 784], [batch_size, 1])
|
|
|
|
|
yield batch_image, batch_label
|
|
|
|
|
|
|
|
|
|
return __reader__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.batch_size = 8
|
|
|
|
@ -51,7 +41,7 @@ class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
self.epoch_num = 1
|
|
|
|
|
self.capacity = 5
|
|
|
|
|
|
|
|
|
|
def test_single_process_reader(self):
|
|
|
|
|
def test_single_process_loader(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
loader = fluid.io.DataLoader.from_generator(
|
|
|
|
|
capacity=self.capacity, iterable=False, use_multiprocess=False)
|
|
|
|
@ -66,7 +56,7 @@ class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
self.assertEqual(label.shape, [self.batch_size, 1])
|
|
|
|
|
self.assertEqual(relu.shape, [self.batch_size, 784])
|
|
|
|
|
|
|
|
|
|
def test_sample_genarator(self):
|
|
|
|
|
def test_multi_process_loader(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
loader = fluid.io.DataLoader.from_generator(
|
|
|
|
|
capacity=self.capacity, use_multiprocess=True)
|
|
|
|
@ -81,20 +71,6 @@ class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
self.assertEqual(label.shape, [self.batch_size, 1])
|
|
|
|
|
self.assertEqual(relu.shape, [self.batch_size, 784])
|
|
|
|
|
|
|
|
|
|
def test_batch_genarator(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
loader = fluid.io.DataLoader.from_generator(
|
|
|
|
|
capacity=self.capacity, use_multiprocess=True)
|
|
|
|
|
loader.set_batch_generator(
|
|
|
|
|
batch_generator_creator(self.batch_size, self.batch_num),
|
|
|
|
|
places=fluid.CPUPlace())
|
|
|
|
|
for _ in range(self.epoch_num):
|
|
|
|
|
for image, label in loader():
|
|
|
|
|
relu = fluid.layers.relu(image)
|
|
|
|
|
self.assertEqual(image.shape, [self.batch_size, 784])
|
|
|
|
|
self.assertEqual(label.shape, [self.batch_size, 1])
|
|
|
|
|
self.assertEqual(relu.shape, [self.batch_size, 784])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|