!12765 Fix numpy input to samplers

From: @hfarahat
Reviewed-by: 
Signed-off-by:
pull/12765/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1c191a65fb

@ -16,13 +16,11 @@
General Validators. General Validators.
""" """
import inspect import inspect
import numbers
from multiprocessing import cpu_count from multiprocessing import cpu_count
import os import os
import numpy as np import numpy as np
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0 # POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1 POS_INT_MIN = 1
@ -389,38 +387,5 @@ def check_tensor_op(param, param_name):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
def check_sampler(sampler):
"""
Check if the sampler is of valid input.
Args:
param(Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): sampler
Returns:
Exception: TypeError if error
"""
builtin = False
base_sampler = False
list_num = False
if sampler is not None:
if isinstance(sampler, samplers.BuiltinSampler):
builtin = True
elif isinstance(sampler, samplers.Sampler):
base_sampler = True
else:
# check for list of numbers
list_num = True
# subset sampler check
subset_sampler = sampler
if not isinstance(sampler, list):
subset_sampler = [sampler]
for _, item in enumerate(subset_sampler):
if not isinstance(item, numbers.Number):
list_num = False
if not (builtin or base_sampler or list_num):
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
def replace_none(value, default): def replace_none(value, default):
return value if value is not None else default return value if value is not None else default

@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
Sampler, sampler selected based on user input. Sampler, sampler selected based on user input.
""" """
def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True
def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)
if isinstance(sampler, list):
return sampler[:number_of_samples]
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if input_sampler is not None: if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because # If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler. # we are being asked specifically to use the given sampler.
@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler): if isinstance(input_sampler, BuiltinSampler):
return input_sampler return input_sampler
if not isinstance(input_sampler, str) and _is_iterable(input_sampler): return SubsetSampler(input_sampler, num_samples)
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
if isinstance(input_sampler, int):
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
if shuffle is None: if shuffle is None:
if num_shards is not None: if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler # If shuffle is not specified, sharding enabled, use distributed random sampler
@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler):
""" """
def __init__(self, indices, num_samples=None): def __init__(self, indices, num_samples=None):
if not isinstance(indices, list): def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True
def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)
if isinstance(sampler, list):
return sampler[:number_of_samples]
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if not isinstance(indices, str) and _is_iterable(indices):
indices = _get_sample_ids_as_list(indices, num_samples)
elif isinstance(indices, int):
indices = [indices] indices = [indices]
else:
raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))
for i, item in enumerate(indices): for i, item in enumerate(indices):
if not isinstance(item, int): if not isinstance(item, (int, np.integer)):
raise TypeError("SubsetSampler: Type of indices element must be int, " raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item))) "but got list[{}]: {}, type: {}.".format(i, item, type(item)))

@ -177,13 +177,23 @@ def test_subset_sampler():
def pipeline(): def pipeline():
sampler = ds.SubsetSampler(indices, num_samples) sampler = ds.SubsetSampler(indices, num_samples)
data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler) data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler)
data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples)
dataset_size = data.get_dataset_size() dataset_size = data.get_dataset_size()
return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size dataset_size2 = data.get_dataset_size()
res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2
return res1, res2
if exception_msg is None: if exception_msg is None:
res, size = pipeline() res, res2 = pipeline()
res, size = res
res2, size2 = res2
if not isinstance(indices, list):
indices = list(indices)
assert indices[:num_samples] == res assert indices[:num_samples] == res
assert len(indices[:num_samples]) == size assert len(indices[:num_samples]) == size
assert indices[:num_samples] == res2
assert len(indices[:num_samples]) == size2
else: else:
with pytest.raises(Exception) as error_info: with pytest.raises(Exception) as error_info:
pipeline() pipeline()
@ -205,6 +215,8 @@ def test_subset_sampler():
test_config([0, 9, 3, 2], num_samples=2) test_config([0, 9, 3, 2], num_samples=2)
test_config([0, 9, 3, 2], num_samples=5) test_config([0, 9, 3, 2], num_samples=5)
test_config(np.array([1, 2, 3]))
test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]") test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]")
test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]") test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]")
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
@ -212,6 +224,9 @@ def test_subset_sampler():
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
test_config([0, 9, 3, 2], num_samples=-1, test_config([0, 9, 3, 2], num_samples=-1,
exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)") exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)")
test_config(np.array([[1], [5]]), num_samples=10,
exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1],"
" type: <class 'numpy.ndarray'>.")
def test_sampler_chain(): def test_sampler_chain():
@ -291,8 +306,8 @@ def test_sampler_list():
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.") msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)") bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)") bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler=np.array([1, 2]), bad_pipeline(sampler=np.array([[1, 2]]),
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.") msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save