@@ -49,6 +49,13 @@ class DistributedBatchSampler(BatchSampler):
                      `__len__` for BatchSampler to get the sample
                      number of the data source.
         batch_size(int): the number of sample indices in each mini-batch.
+        num_replicas(int, optional): process number in distributed training.
+            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
+            retrieved from :code:`paddle.fluid.dygraph.parallel.ParallelEnv`.
+            Default None.
+        rank(int, optional): the rank of the current process among :attr:`num_replicas`
+            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
+            :code:`paddle.fluid.dygraph.parallel.ParallelEnv`. Default None.
         shuffle(bool): whether to shuffle indices order before generating
             batch indices. Default False.
         drop_last(bool): whether to drop the last incomplete batch if the dataset size
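
For illustration, a minimal usage sketch of the new keyword arguments (the
`paddle.io` import path and the toy dataset are assumptions for this sketch,
not part of the patch):

    from paddle.io import Dataset, DistributedBatchSampler

    class RangeDataset(Dataset):
        def __getitem__(self, idx):
            return idx

        def __len__(self):
            return 100

    # Pass num_replicas/rank explicitly; when either is omitted, it falls
    # back to ParallelEnv(), as the docstring above describes.
    sampler = DistributedBatchSampler(
        RangeDataset(), batch_size=8, num_replicas=4, rank=0, shuffle=True)
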
@@ -84,7 +91,13 @@ class DistributedBatchSampler(BatchSampler):
                 break
     """

-    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=False,
+                 drop_last=False):
         self.dataset = dataset

         assert isinstance(batch_size, int) and batch_size > 0, \
@@ -96,9 +109,21 @@ class DistributedBatchSampler(BatchSampler):
         assert isinstance(drop_last, bool), \
             "drop_last should be a boolean value"

+        if num_replicas is not None:
+            assert isinstance(num_replicas, int) and num_replicas > 0, \
+                "num_replicas should be a positive integer"
+            self.nranks = num_replicas
+        else:
+            self.nranks = ParallelEnv().nranks
+
+        if rank is not None:
+            assert isinstance(rank, int) and rank >= 0, \
+                "rank should be a non-negative integer"
+            self.local_rank = rank
+        else:
+            self.local_rank = ParallelEnv().local_rank
+
         self.drop_last = drop_last
-        self.nranks = ParallelEnv().nranks
-        self.local_rank = ParallelEnv().local_rank
         self.epoch = 0
         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
         self.total_size = self.num_samples * self.nranks
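
The two context lines at the bottom of this hunk determine per-rank padding:
`num_samples` rounds the per-rank share up with `math.ceil`, and `total_size`
is that rounded share times `nranks`, so every rank yields the same number of
batches. A small standalone check of that arithmetic (not part of the patch):

    import math

    dataset_len, nranks = 10, 4
    num_samples = int(math.ceil(dataset_len * 1.0 / nranks))  # ceil(2.5) -> 3
    total_size = num_samples * nranks                         # 3 * 4 -> 12
    assert (num_samples, total_size) == (3, 12)
    # total_size - dataset_len = 2 indices get padded so that all 4 ranks
    # stay in lockstep.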