|
|
@ -170,5 +170,50 @@ class TestStaticDataLoader(unittest.TestCase):
|
|
|
|
self.assertLess(diff, 1e-2)
|
|
|
|
self.assertLess(diff, 1e-2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestStaticDataLoaderReturnList(unittest.TestCase):
|
|
|
|
|
|
|
|
def test_single_place(self):
|
|
|
|
|
|
|
|
scope = fluid.Scope()
|
|
|
|
|
|
|
|
image = fluid.data(
|
|
|
|
|
|
|
|
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
|
|
|
|
|
|
|
|
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
|
|
|
|
|
|
|
|
with fluid.scope_guard(scope):
|
|
|
|
|
|
|
|
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
|
|
|
|
|
|
|
|
dataloader = DataLoader(
|
|
|
|
|
|
|
|
dataset,
|
|
|
|
|
|
|
|
feed_list=[image, label],
|
|
|
|
|
|
|
|
num_workers=0,
|
|
|
|
|
|
|
|
batch_size=BATCH_SIZE,
|
|
|
|
|
|
|
|
drop_last=True,
|
|
|
|
|
|
|
|
return_list=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for d in dataloader:
|
|
|
|
|
|
|
|
assert isinstance(d, list)
|
|
|
|
|
|
|
|
assert len(d) == 2
|
|
|
|
|
|
|
|
assert not isinstance(d[0], list)
|
|
|
|
|
|
|
|
assert not isinstance(d[1], list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_multi_place(self):
|
|
|
|
|
|
|
|
scope = fluid.Scope()
|
|
|
|
|
|
|
|
image = fluid.data(
|
|
|
|
|
|
|
|
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
|
|
|
|
|
|
|
|
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
|
|
|
|
|
|
|
|
with fluid.scope_guard(scope):
|
|
|
|
|
|
|
|
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
|
|
|
|
|
|
|
|
dataloader = DataLoader(
|
|
|
|
|
|
|
|
dataset,
|
|
|
|
|
|
|
|
feed_list=[image, label],
|
|
|
|
|
|
|
|
num_workers=0,
|
|
|
|
|
|
|
|
batch_size=BATCH_SIZE,
|
|
|
|
|
|
|
|
places=[fluid.CPUPlace()] * 2,
|
|
|
|
|
|
|
|
drop_last=True,
|
|
|
|
|
|
|
|
return_list=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for d in dataloader:
|
|
|
|
|
|
|
|
assert isinstance(d, list)
|
|
|
|
|
|
|
|
assert len(d) == 2
|
|
|
|
|
|
|
|
assert isinstance(d[0], list)
|
|
|
|
|
|
|
|
assert isinstance(d[1], list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
unittest.main()
|
|
|
|
unittest.main()
|
|
|
|