|
|
|
@ -36,6 +36,11 @@ def generate_2_columns(n):
|
|
|
|
|
yield (np.array([i]), np.array([j for j in range(i + 1)]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_3_columns(n):
|
|
|
|
|
for i in range(n):
|
|
|
|
|
yield (np.array([i]), np.array([i + 1]), np.array([j for j in range(i + 1)]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_invalid_input():
|
|
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
|
|
|
|
|
|
|
|
@ -382,6 +387,48 @@ def test_bucket_batch_multi_column():
|
|
|
|
|
assert same_shape_output == same_shape_expected_output
|
|
|
|
|
assert variable_shape_output == variable_shape_expected_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_three_columns():
|
|
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_3_columns(10)), ["same_shape", "same_shape2", "variable_shape"])
|
|
|
|
|
|
|
|
|
|
column_names = ["same_shape2"]
|
|
|
|
|
bucket_boundaries = [6, 12]
|
|
|
|
|
bucket_batch_sizes = [5, 5, 1]
|
|
|
|
|
element_length_function = (lambda x: x[0] % 3)
|
|
|
|
|
pad_info = {}
|
|
|
|
|
|
|
|
|
|
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
|
|
|
|
|
bucket_batch_sizes, element_length_function,
|
|
|
|
|
pad_info)
|
|
|
|
|
|
|
|
|
|
same_shape_expected_output = [[[0], [1], [2], [3], [4]],
|
|
|
|
|
[[5], [6], [7], [8], [9]]]
|
|
|
|
|
same_shape2_expected_output = [[[1], [2], [3], [4], [5]],
|
|
|
|
|
[[6], [7], [8], [9], [10]]]
|
|
|
|
|
variable_shape_expected_output = [[[0, 0, 0, 0, 0],
|
|
|
|
|
[0, 1, 0, 0, 0],
|
|
|
|
|
[0, 1, 2, 0, 0],
|
|
|
|
|
[0, 1, 2, 3, 0],
|
|
|
|
|
[0, 1, 2, 3, 4]],
|
|
|
|
|
[[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
|
|
|
|
|
[0, 1, 2, 3, 4, 5, 6, 0, 0, 0],
|
|
|
|
|
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
|
|
|
|
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
|
|
|
|
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]
|
|
|
|
|
|
|
|
|
|
same_shape_output = []
|
|
|
|
|
same_shape2_output = []
|
|
|
|
|
variable_shape_output = []
|
|
|
|
|
for data in dataset.create_dict_iterator(num_epochs=1):
|
|
|
|
|
same_shape_output.append(data["same_shape"].tolist())
|
|
|
|
|
same_shape2_output.append(data["same_shape2"].tolist())
|
|
|
|
|
variable_shape_output.append(data["variable_shape"].tolist())
|
|
|
|
|
|
|
|
|
|
assert same_shape_output == same_shape_expected_output
|
|
|
|
|
assert same_shape2_output == same_shape2_expected_output
|
|
|
|
|
assert variable_shape_output == variable_shape_expected_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_get_dataset_size():
|
|
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
|
|
|
|
|
|
|
|
@ -402,6 +449,25 @@ def test_bucket_batch_get_dataset_size():
|
|
|
|
|
assert data_size == num_rows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_invalid_column():
|
|
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
|
|
|
|
|
|
|
|
|
|
column_names = ["invalid_column"]
|
|
|
|
|
bucket_boundaries = [1, 2, 3]
|
|
|
|
|
bucket_batch_sizes = [3, 3, 2, 2]
|
|
|
|
|
element_length_function = (lambda x: x[0] % 4)
|
|
|
|
|
|
|
|
|
|
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
|
|
|
|
|
bucket_batch_sizes, element_length_function)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(RuntimeError) as info:
|
|
|
|
|
num_rows = 0
|
|
|
|
|
for _ in dataset.create_dict_iterator(num_epochs=1):
|
|
|
|
|
num_rows += 1
|
|
|
|
|
|
|
|
|
|
assert "BucketBatchByLength: Couldn't find the specified column in the dataset" in str(info.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
test_bucket_batch_invalid_input()
|
|
|
|
|
test_bucket_batch_multi_bucket_no_padding()
|
|
|
|
@ -413,4 +479,6 @@ if __name__ == '__main__':
|
|
|
|
|
test_bucket_batch_drop_remainder()
|
|
|
|
|
test_bucket_batch_default_length_function()
|
|
|
|
|
test_bucket_batch_multi_column()
|
|
|
|
|
test_bucket_batch_three_columns()
|
|
|
|
|
test_bucket_batch_get_dataset_size()
|
|
|
|
|
test_bucket_batch_invalid_column()
|
|
|
|
|