|
|
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|
|
|
|
from __future__ import division
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from .sampler import Sampler, SequenceSampler
|
|
|
|
|
from .sampler import Sampler, SequenceSampler, RandomSampler
|
|
|
|
|
from .dataset import Dataset, IterableDataset
|
|
|
|
|
|
|
|
|
|
__all__ = ["BatchSampler"]
|
|
|
|
@ -86,7 +86,6 @@ class BatchSampler(Sampler):
|
|
|
|
|
# init with sampler
|
|
|
|
|
sampler = RandomSampler(RandomDataset(100))
|
|
|
|
|
bs = BatchSampler(sampler=sampler,
|
|
|
|
|
shuffle=True,
|
|
|
|
|
batch_size=8,
|
|
|
|
|
drop_last=True)
|
|
|
|
|
|
|
|
|
@ -118,14 +117,16 @@ class BatchSampler(Sampler):
|
|
|
|
|
"dataset should not be a paddle.io.IterableDataset"
|
|
|
|
|
assert sampler is None, \
|
|
|
|
|
"should not set both dataset and sampler"
|
|
|
|
|
self.sampler = SequenceSampler(dataset)
|
|
|
|
|
assert isinstance(shuffle, bool), \
|
|
|
|
|
"shuffle should be a boolean value, but got {}".format(type(shuffle))
|
|
|
|
|
if shuffle:
|
|
|
|
|
self.sampler = RandomSampler(dataset)
|
|
|
|
|
else:
|
|
|
|
|
self.sampler = SequenceSampler(dataset)
|
|
|
|
|
|
|
|
|
|
assert isinstance(batch_size, int) and batch_size > 0, \
|
|
|
|
|
"batch_size should be a positive integer, but got {}".format(batch_size)
|
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|
assert isinstance(shuffle, bool), \
|
|
|
|
|
"shuffle should be a boolean value, but got {}".format(type(shuffle))
|
|
|
|
|
self.shuffle = shuffle
|
|
|
|
|
assert isinstance(drop_last, bool), \
|
|
|
|
|
"drop_last should be a boolean value, but got {}".format(type(drop_last))
|
|
|
|
|
self.drop_last = drop_last
|
|
|
|
|