|
|
@ -223,7 +223,8 @@ class DistributedSampler(BuiltinSampler):
|
|
|
|
shard_id (int): Shard ID of the current shard within num_shards.
|
|
|
|
shard_id (int): Shard ID of the current shard within num_shards.
|
|
|
|
shuffle (bool, optional): If True, the indices are shuffled (default=True).
|
|
|
|
shuffle (bool, optional): If True, the indices are shuffled (default=True).
|
|
|
|
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
|
|
|
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
|
|
|
offset(int, optional): The starting sample ID where access to elements in the dataset begins (default=-1).
|
|
|
|
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
|
|
|
|
|
|
|
|
should be no more than num_shards.
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
|
>>> import mindspore.dataset as ds
|
|
|
|
>>> import mindspore.dataset as ds
|
|
|
@ -238,6 +239,7 @@ class DistributedSampler(BuiltinSampler):
|
|
|
|
ValueError: If num_shards is not positive.
|
|
|
|
ValueError: If num_shards is not positive.
|
|
|
|
ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
|
|
|
ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
|
|
|
ValueError: If shuffle is not a boolean value.
|
|
|
|
ValueError: If shuffle is not a boolean value.
|
|
|
|
|
|
|
|
ValueError: If offset is greater than num_shards.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
|
|
|
|
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
|
|
|
@ -255,6 +257,10 @@ class DistributedSampler(BuiltinSampler):
|
|
|
|
raise ValueError("num_samples should be a positive integer "
|
|
|
|
raise ValueError("num_samples should be a positive integer "
|
|
|
|
"value, but got num_samples={}".format(num_samples))
|
|
|
|
"value, but got num_samples={}".format(num_samples))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if offset > num_shards:
|
|
|
|
|
|
|
|
raise ValueError("offset should be no more than num_shards={}, "
|
|
|
|
|
|
|
|
"but got offset={}".format(num_shards, offset))
|
|
|
|
|
|
|
|
|
|
|
|
self.num_shards = num_shards
|
|
|
|
self.num_shards = num_shards
|
|
|
|
self.shard_id = shard_id
|
|
|
|
self.shard_id = shard_id
|
|
|
|
self.shuffle = shuffle
|
|
|
|
self.shuffle = shuffle
|
|
|
|