|
|
|
@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input():
|
|
|
|
|
negative_bucket_batch_sizes = [1, 2, 3, -4]
|
|
|
|
|
zero_bucket_batch_sizes = [0, 1, 2, 3]
|
|
|
|
|
|
|
|
|
|
invalid_type_pad_to_bucket_boundary = ""
|
|
|
|
|
invalid_type_drop_remainder = ""
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info:
|
|
|
|
|
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
|
|
|
|
|
assert "column_names should be a list of str" in str(info.value)
|
|
|
|
@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input():
|
|
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
|
|
|
|
|
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info:
|
|
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
|
|
|
|
None, None, invalid_type_pad_to_bucket_boundary)
|
|
|
|
|
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info:
|
|
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
|
|
|
|
|
None, None, False, invalid_type_drop_remainder)
|
|
|
|
|
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_multi_bucket_no_padding():
|
|
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
|
|
|
|