fix dataloader default value and doc (#28728)

* fix dataloader. test=develop
musl/fix_failed_unittests_in_musl
Kaipeng Deng 4 years ago committed by GitHub
parent 0ed80e09fc
commit 91bab752a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -196,7 +196,7 @@ class DataLoader(object):
the key of the dict is the name of each fed variables. If
:attr:`return_list=True`, the return value on each device would
be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default False.
in dynamic graph mode. Default True.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None.
@ -308,7 +308,7 @@ class DataLoader(object):
dataset,
feed_list=None,
places=None,
return_list=False,
return_list=True,
batch_sampler=None,
batch_size=1,
shuffle=False,
@ -403,10 +403,10 @@ class DataLoader(object):
if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported")
else:
if self.batch_size is None:
return len(self.dataset)
else:
if self.auto_collate_batch:
return len(self.batch_sampler)
else:
return len(self.dataset)
def __iter__(self):
if self.num_workers == 0:

@ -112,6 +112,7 @@ class TestStaticDataLoader(unittest.TestCase):
places=places,
num_workers=num_workers,
batch_size=BATCH_SIZE,
return_list=False,
drop_last=True)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
@ -199,6 +200,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
places=places,
num_workers=num_workers,
batch_size=None,
return_list=False,
drop_last=True)
exe = fluid.Executor(place=places[0])

@ -113,6 +113,7 @@ class TestStaticDataLoader(unittest.TestCase):
places=places,
num_workers=num_workers,
batch_size=BATCH_SIZE,
return_list=False,
drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
@ -226,7 +227,8 @@ class RandomBatchedDataset(Dataset):
labels = []
for _ in range(BATCH_SIZE):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
label = np.random.randint(0, self.class_num - 1,
(1, )).astype('int64')
images.append(image)
labels.append(label)
return np.stack(images, axis=0), np.stack(labels, axis=0)
@ -248,6 +250,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
places=places,
num_workers=num_workers,
batch_size=None,
return_list=False,
drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

Loading…
Cancel
Save