diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 3aca3a383f..d468ef822f 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -19,6 +19,7 @@ SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler. Users can also define a custom sampler by extending from the Sampler class. """ +import numbers import numpy as np import mindspore._c_dataengine as cde import mindspore.dataset as ds @@ -591,6 +592,20 @@ class WeightedRandomSampler(BuiltinSampler): if not isinstance(weights, list): weights = [weights] + for ind, w in enumerate(weights): + if not isinstance(w, numbers.Number): + raise TypeError("type of weights element should be number, " + "but got w[{}]={}, type={}".format(ind, w, type(w))) + + if weights == []: + raise ValueError("weights size should not be 0") + + if list(filter(lambda x: x < 0, weights)) != []: + raise ValueError("weights should not contain negative numbers") + + if list(filter(lambda x: x == 0, weights)) == weights: + raise ValueError("elements of weights should not be all zero") + if num_samples is not None: if num_samples <= 0: raise ValueError("num_samples should be a positive integer " diff --git a/tests/ut/python/dataset/test_datasets_imagefolder.py b/tests/ut/python/dataset/test_datasets_imagefolder.py index 3f0638c647..94ba52953a 100644 --- a/tests/ut/python/dataset/test_datasets_imagefolder.py +++ b/tests/ut/python/dataset/test_datasets_imagefolder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -382,6 +383,35 @@ def test_weighted_random_sampler(): logger.info("Number of data in data1: {}".format(num_iter)) assert num_iter == 11 +def test_weighted_random_sampler_exception(): + """ + Test error cases for WeightedRandomSampler + """ + logger.info("Test error cases for WeightedRandomSampler") + error_msg_1 = "type of weights element should be number" + with pytest.raises(TypeError, match=error_msg_1): + weights = "" + ds.WeightedRandomSampler(weights) + + error_msg_2 = "type of weights element should be number" + with pytest.raises(TypeError, match=error_msg_2): + weights = (0.9, 0.8, 1.1) + ds.WeightedRandomSampler(weights) + + error_msg_3 = "weights size should not be 0" + with pytest.raises(ValueError, match=error_msg_3): + weights = [] + ds.WeightedRandomSampler(weights) + + error_msg_4 = "weights should not contain negative numbers" + with pytest.raises(ValueError, match=error_msg_4): + weights = [1.0, 0.1, 0.02, 0.3, -0.4] + ds.WeightedRandomSampler(weights) + + error_msg_5 = "elements of weights should not be all zero" + with pytest.raises(ValueError, match=error_msg_5): + weights = [0, 0, 0, 0, 0] + ds.WeightedRandomSampler(weights) def test_imagefolder_rename(): logger.info("Test Case rename") @@ -465,6 +495,9 @@ if __name__ == '__main__': test_weighted_random_sampler() logger.info('test_weighted_random_sampler Ended.\n') + test_weighted_random_sampler_exception() + logger.info('test_weighted_random_sampler_exception Ended.\n') + test_imagefolder_numshards() logger.info('test_imagefolder_numshards Ended.\n')