|
|
|
@ -41,6 +41,14 @@ class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
self.epoch_num = 1
|
|
|
|
|
self.capacity = 5
|
|
|
|
|
|
|
|
|
|
def iter_loader_data(self, loader):
|
|
|
|
|
for _ in range(self.epoch_num):
|
|
|
|
|
for image, label in loader():
|
|
|
|
|
relu = fluid.layers.relu(image)
|
|
|
|
|
self.assertEqual(image.shape, [self.batch_size, 784])
|
|
|
|
|
self.assertEqual(label.shape, [self.batch_size, 1])
|
|
|
|
|
self.assertEqual(relu.shape, [self.batch_size, 784])
|
|
|
|
|
|
|
|
|
|
def test_single_process_loader(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
loader = fluid.io.DataLoader.from_generator(
|
|
|
|
@ -49,12 +57,7 @@ class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
sample_generator_creator(self.batch_size, self.batch_num),
|
|
|
|
|
batch_size=self.batch_size,
|
|
|
|
|
places=fluid.CPUPlace())
|
|
|
|
|
for _ in range(self.epoch_num):
|
|
|
|
|
for image, label in loader():
|
|
|
|
|
relu = fluid.layers.relu(image)
|
|
|
|
|
self.assertEqual(image.shape, [self.batch_size, 784])
|
|
|
|
|
self.assertEqual(label.shape, [self.batch_size, 1])
|
|
|
|
|
self.assertEqual(relu.shape, [self.batch_size, 784])
|
|
|
|
|
self.iter_loader_data(loader)
|
|
|
|
|
|
|
|
|
|
def test_multi_process_loader(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
@ -64,12 +67,15 @@ class TestDygraphDataLoader(unittest.TestCase):
|
|
|
|
|
sample_generator_creator(self.batch_size, self.batch_num),
|
|
|
|
|
batch_size=self.batch_size,
|
|
|
|
|
places=fluid.CPUPlace())
|
|
|
|
|
for _ in range(self.epoch_num):
|
|
|
|
|
for image, label in loader():
|
|
|
|
|
relu = fluid.layers.relu(image)
|
|
|
|
|
self.assertEqual(image.shape, [self.batch_size, 784])
|
|
|
|
|
self.assertEqual(label.shape, [self.batch_size, 1])
|
|
|
|
|
self.assertEqual(relu.shape, [self.batch_size, 784])
|
|
|
|
|
self.iter_loader_data(loader)
|
|
|
|
|
|
|
|
|
|
def test_generator_no_places(self):
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
loader = fluid.io.DataLoader.from_generator(capacity=self.capacity)
|
|
|
|
|
loader.set_sample_generator(
|
|
|
|
|
sample_generator_creator(self.batch_size, self.batch_num),
|
|
|
|
|
batch_size=self.batch_size)
|
|
|
|
|
self.iter_loader_data(loader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|