|
|
|
@ -25,6 +25,82 @@ import mindspore._c_dataengine as cde
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Create sampler based on user input.

    Args:
        num_samples (int): Number of samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user.
        shuffle (bool): Shuffle.
        num_shards (int): Number of shard for sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input. Note that when
        `input_sampler` is a plain int, a one-element list `[input_sampler]`
        is returned instead of a sampler object.

    Raises:
        ValueError: If `input_sampler` is a BuiltinSampler and any of
            `num_shards`, `shard_id`, `shuffle`, `num_samples` is not None.
        ValueError: If `input_sampler` is neither a BuiltinSampler, an
            iterable, nor an int.
    """

    def _is_iterable(obj):
        """Return True if `iter(obj)` succeeds, False if it raises TypeError."""
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    def _get_sample_ids_as_list(sampler, number_of_samples=None):
        """Collect the ids yielded by `sampler` into a list, keeping at most `number_of_samples`."""
        if number_of_samples is None:
            return list(sampler)

        if isinstance(sampler, list):
            return sampler[:number_of_samples]

        # `range` is placed first so that zip stops as soon as the requested count is
        # reached; with the sampler first, zip would pull (and silently drop) one extra
        # id from a one-shot iterator before noticing that the range is exhausted.
        return [sample_id for _, sample_id in zip(range(number_of_samples), sampler)]

    if input_sampler is not None:
        if isinstance(input_sampler, BuiltinSampler):
            # If the user provided a sampler, then it doesn't matter what the other args are because
            # we are being asked specifically to use the given sampler.
            # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
            # be None. Consider this example:
            #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
            #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
            # In this case, the user has given different sample-related arguments that contradict each other.
            # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
            if any(arg is not None for arg in (num_shards, shard_id, shuffle, num_samples)):
                raise ValueError(
                    'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                    ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
            return input_sampler

        if _is_iterable(input_sampler):
            # An arbitrary iterable of ids is wrapped in a SubsetSampler, truncated to num_samples if given.
            return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))

        if isinstance(input_sampler, int):
            # NOTE(review): a bare int is returned as a one-element list, not as a sampler object;
            # kept as-is because callers may rely on it - confirm against the call sites.
            return [input_sampler]

        raise ValueError('Unsupported sampler object ({})'.format(input_sampler))

    # No sampler supplied: derive one from shuffle / sharding options.
    # An unspecified shuffle (None) is treated exactly like shuffle=True.
    if shuffle is None or shuffle is True:
        if num_shards is not None:
            # Shuffle enabled (or unspecified), sharding enabled, use distributed random sampler
            return DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=num_samples)
        # Shuffle enabled (or unspecified), sharding disabled, use random sampler
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)

    if num_shards is not None:
        # If shuffle disabled, sharding enabled, use distributed sequential sampler
        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)

    # If shuffle disabled, sharding disabled, use sequential sampler
    return SequentialSampler(num_samples=num_samples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BuiltinSampler:
|
|
|
|
|
"""
|
|
|
|
|
Base class for BuiltinSampler.
|
|
|
|
|