|
|
|
@ -73,11 +73,11 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|
|
|
|
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
|
|
|
|
|
if isinstance(input_sampler, BuiltinSampler):
|
|
|
|
|
return input_sampler
|
|
|
|
|
if _is_iterable(input_sampler):
|
|
|
|
|
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
|
|
|
|
|
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
|
|
|
|
|
if isinstance(input_sampler, int):
|
|
|
|
|
return [input_sampler]
|
|
|
|
|
raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
|
|
|
|
|
return SubsetSampler([input_sampler])
|
|
|
|
|
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
|
|
|
|
|
if shuffle is None:
|
|
|
|
|
if num_shards is not None:
|
|
|
|
|
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
|
|
|
@ -644,9 +644,9 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
|
indices = [indices]
|
|
|
|
|
|
|
|
|
|
for i, item in enumerate(indices):
|
|
|
|
|
if not isinstance(item, numbers.Number):
|
|
|
|
|
raise TypeError("type of indices element must be number, "
|
|
|
|
|
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
|
|
|
|
|
if not isinstance(item, int):
|
|
|
|
|
raise TypeError("SubsetSampler: Type of indices element must be int, "
|
|
|
|
|
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|