Fix DataLoader `return_list` default value and docs (#28728)

* fix dataloader. test=develop
musl/fix_failed_unittests_in_musl
Kaipeng Deng 4 years ago committed by GitHub
parent 0ed80e09fc
commit 91bab752a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -196,7 +196,7 @@ class DataLoader(object):
the key of the dict is the name of each fed variable. If the key of the dict is the name of each fed variable. If
:attr:`return_list=True`, the return value on each device would :attr:`return_list=True`, the return value on each device would
be a list(Tensor). :attr:`return_list` can only be True be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default False. in dynamic graph mode. Default True.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler` batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset` to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None. and combine a batch. Default None.
@ -308,7 +308,7 @@ class DataLoader(object):
dataset, dataset,
feed_list=None, feed_list=None,
places=None, places=None,
return_list=False, return_list=True,
batch_sampler=None, batch_sampler=None,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
@ -403,10 +403,10 @@ class DataLoader(object):
if self.dataset_kind == _DatasetKind.ITER: if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported") raise ValueError("length of IterableDataset not supported")
else: else:
if self.batch_size is None: if self.auto_collate_batch:
return len(self.dataset)
else:
return len(self.batch_sampler) return len(self.batch_sampler)
else:
return len(self.dataset)
def __iter__(self): def __iter__(self):
if self.num_workers == 0: if self.num_workers == 0:

@ -112,6 +112,7 @@ class TestStaticDataLoader(unittest.TestCase):
places=places, places=places,
num_workers=num_workers, num_workers=num_workers,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
return_list=False,
drop_last=True) drop_last=True)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) # assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
@ -199,6 +200,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
places=places, places=places,
num_workers=num_workers, num_workers=num_workers,
batch_size=None, batch_size=None,
return_list=False,
drop_last=True) drop_last=True)
exe = fluid.Executor(place=places[0]) exe = fluid.Executor(place=places[0])

@ -113,6 +113,7 @@ class TestStaticDataLoader(unittest.TestCase):
places=places, places=places,
num_workers=num_workers, num_workers=num_workers,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
return_list=False,
drop_last=True) drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
@ -226,7 +227,8 @@ class RandomBatchedDataset(Dataset):
labels = [] labels = []
for _ in range(BATCH_SIZE): for _ in range(BATCH_SIZE):
image = np.random.random([IMAGE_SIZE]).astype('float32') image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64') label = np.random.randint(0, self.class_num - 1,
(1, )).astype('int64')
images.append(image) images.append(image)
labels.append(label) labels.append(label)
return np.stack(images, axis=0), np.stack(labels, axis=0) return np.stack(images, axis=0), np.stack(labels, axis=0)
@ -248,6 +250,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
places=places, places=places,
num_workers=num_workers, num_workers=num_workers,
batch_size=None, batch_size=None,
return_list=False,
drop_last=True) drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

Loading…
Cancel
Save