|
|
|
@ -374,6 +374,28 @@ def test_multi_col_map():
|
|
|
|
|
assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_exceptions_2():
|
|
|
|
|
def gen(num):
|
|
|
|
|
for i in range(num):
|
|
|
|
|
yield (np.array([i]),)
|
|
|
|
|
|
|
|
|
|
def simple_copy(colList, batchInfo):
|
|
|
|
|
return ([np.copy(arr) for arr in colList],)
|
|
|
|
|
|
|
|
|
|
def test_wrong_col_name(gen_num, batch_size):
|
|
|
|
|
data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"],
|
|
|
|
|
per_batch_map=simple_copy)
|
|
|
|
|
try:
|
|
|
|
|
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
|
|
|
|
pass
|
|
|
|
|
return "success"
|
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
return str(e)
|
|
|
|
|
|
|
|
|
|
# test exception where column name is incorrect
|
|
|
|
|
assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
logger.info("Running test_var_batch_map.py test_batch_corner_cases() function")
|
|
|
|
|
test_batch_corner_cases()
|
|
|
|
@ -398,3 +420,6 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
logger.info("Running test_var_batch_map.py test_multi_col_map() function")
|
|
|
|
|
test_multi_col_map()
|
|
|
|
|
|
|
|
|
|
logger.info("Running test_var_batch_map.py test_exceptions_2() function")
|
|
|
|
|
test_exceptions_2()
|
|
|
|
|