|
|
@ -605,7 +605,7 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
Samples the elements from a sequence of indices.
|
|
|
|
Samples the elements from a sequence of indices.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
indices (list[int]): A sequence of indices.
|
|
|
|
indices (Any iterable python object but string): A sequence of indices.
|
|
|
|
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
|
|
|
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
@ -633,6 +633,13 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
|
|
|
|
|
|
|
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
|
|
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(indices, str) and validator.is_iterable(indices):
|
|
|
|
if not isinstance(indices, str) and validator.is_iterable(indices):
|
|
|
|
indices = _get_sample_ids_as_list(indices, num_samples)
|
|
|
|
indices = _get_sample_ids_as_list(indices, num_samples)
|
|
|
|
elif isinstance(indices, int):
|
|
|
|
elif isinstance(indices, int):
|
|
|
@ -645,13 +652,6 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
raise TypeError("SubsetSampler: Type of indices element must be int, "
|
|
|
|
raise TypeError("SubsetSampler: Type of indices element must be int, "
|
|
|
|
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
|
|
|
|
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
|
|
|
|
|
|
|
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.indices = indices
|
|
|
|
self.indices = indices
|
|
|
|
super().__init__(num_samples)
|
|
|
|
super().__init__(num_samples)
|
|
|
|
|
|
|
|
|
|
|
|