bucket_batch_sizes must be strictly positive, 0 is not a valid batch size

pull/2737/head
peilin-wang 5 years ago
parent a5c1e09469
commit 9468b49e28

@ -643,9 +643,9 @@ def check_bucket_batch_by_length(method):
if not all_int:
raise TypeError("bucket_batch_sizes should be a list of int.")
all_non_negative = all(item >= 0 for item in bucket_batch_sizes)
all_non_negative = all(item > 0 for item in bucket_batch_sizes)
if not all_non_negative:
raise ValueError("bucket_batch_sizes cannot contain any negative numbers.")
raise ValueError("bucket_batch_sizes should be a list of positive numbers.")
if param_dict.get('pad_info') is not None:
check_type(param_dict["pad_info"], "pad_info", dict)

@ -51,6 +51,7 @@ def test_bucket_batch_invalid_input():
bucket_batch_sizes = [1, 1, 1, 1]
invalid_bucket_batch_sizes = ["1", "2", "3", "4"]
negative_bucket_batch_sizes = [1, 2, 3, -4]
zero_bucket_batch_sizes = [0, 1, 2, 3]
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
@ -82,7 +83,11 @@ def test_bucket_batch_invalid_input():
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes)
assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value)
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value)
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, zero_bucket_batch_sizes)
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value)
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)

Loading…
Cancel
Save