From 9468b49e28b904e06c2749ba41627cd2852b1f73 Mon Sep 17 00:00:00 2001 From: peilin-wang Date: Mon, 29 Jun 2020 16:34:39 -0400 Subject: [PATCH] bucket_batch_sizes must be strictly positive, 0 is not a valid batch size --- mindspore/dataset/engine/validators.py | 4 ++-- tests/ut/python/dataset/test_bucket_batch_by_length.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 9857608c19..2a0bef3b42 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -643,9 +643,9 @@ def check_bucket_batch_by_length(method): if not all_int: raise TypeError("bucket_batch_sizes should be a list of int.") - all_non_negative = all(item >= 0 for item in bucket_batch_sizes) + all_non_negative = all(item > 0 for item in bucket_batch_sizes) if not all_non_negative: - raise ValueError("bucket_batch_sizes cannot contain any negative numbers.") + raise ValueError("bucket_batch_sizes should be a list of positive numbers.") if param_dict.get('pad_info') is not None: check_type(param_dict["pad_info"], "pad_info", dict) diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index bca30723e9..4436f98e53 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -51,6 +51,7 @@ def test_bucket_batch_invalid_input(): bucket_batch_sizes = [1, 1, 1, 1] invalid_bucket_batch_sizes = ["1", "2", "3", "4"] negative_bucket_batch_sizes = [1, 2, 3, -4] + zero_bucket_batch_sizes = [0, 1, 2, 3] with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) @@ -82,7 +83,11 @@ def test_bucket_batch_invalid_input(): with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) - assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) + assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, zero_bucket_batch_sizes) + assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value) with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)