fix DataLoader return same format between static & dynamic in single mode (#28176)

* fix DataLoader return same format between static & dynamic in single mode. test=develop
revert-27871-prv-conv-grad-opt
Kaipeng Deng 4 years ago committed by GitHub
parent 7db747d9e8
commit 4671d85a03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -341,6 +341,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
return self._reader.read_next_var_list()
else:
if self._return_list:
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
if len(self._places) == 1:
return self._reader.read_next_list()[0]
else:
return self._reader.read_next_list()
else:
return self._reader.read_next()

@ -170,5 +170,50 @@ class TestStaticDataLoader(unittest.TestCase):
self.assertLess(diff, 1e-2)
class TestStaticDataLoaderReturnList(unittest.TestCase):
def test_single_place(self):
scope = fluid.Scope()
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
with fluid.scope_guard(scope):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
num_workers=0,
batch_size=BATCH_SIZE,
drop_last=True,
return_list=True)
for d in dataloader:
assert isinstance(d, list)
assert len(d) == 2
assert not isinstance(d[0], list)
assert not isinstance(d[1], list)
def test_multi_place(self):
scope = fluid.Scope()
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
with fluid.scope_guard(scope):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
num_workers=0,
batch_size=BATCH_SIZE,
places=[fluid.CPUPlace()] * 2,
drop_last=True,
return_list=True)
for d in dataloader:
assert isinstance(d, list)
assert len(d) == 2
assert isinstance(d[0], list)
assert isinstance(d[1], list)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save