fix batch core dump when col name doesn't exist

fix ci
pull/7264/head
Zirui Wu 4 years ago
parent 59fb711fcf
commit c42a1b4a2f

@ -489,6 +489,11 @@ Status BatchOp::ComputeColMap() {
// from this point onward, per_batch_map is needed, therefore, child_map_ must be set
child_map_ = child_[0]->column_name_id_map();
// check all input columns exist
for (const auto &col : in_col_names_) {
CHECK_FAIL_RETURN_UNEXPECTED(child_map_.find(col) != child_map_.end(), "col:" + col + " doesn't exist.");
}
// following logic deals with per_batch_map
bool col_name_flag = (out_col_names_.empty() || out_col_names_ == in_col_names_); // true if col name is unchanged
@ -498,13 +503,11 @@ Status BatchOp::ComputeColMap() {
column_name_id_map_ = child_map_;
return Status::OK();
}
// column names are changed from this point onward, this map is the child_map without input cols for per_batch_map
auto child_map_no_in_col = child_map_;
for (const auto &col : in_col_names_) {
const auto itr = child_map_no_in_col.find(col);
CHECK_FAIL_RETURN_UNEXPECTED(itr != child_map_no_in_col.end(), "col:" + col + " doesn't exist.");
child_map_no_in_col.erase(itr);
child_map_no_in_col.erase(col);
}
// col names are changed

@ -374,6 +374,28 @@ def test_multi_col_map():
assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"])
def test_exceptions_2():
def gen(num):
for i in range(num):
yield (np.array([i]),)
def simple_copy(colList, batchInfo):
return ([np.copy(arr) for arr in colList],)
def test_wrong_col_name(gen_num, batch_size):
data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"],
per_batch_map=simple_copy)
try:
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
return "success"
except RuntimeError as e:
return str(e)
# test exception where column name is incorrect
assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2)
if __name__ == '__main__':
logger.info("Running test_var_batch_map.py test_batch_corner_cases() function")
test_batch_corner_cases()
@ -398,3 +420,6 @@ if __name__ == '__main__':
logger.info("Running test_var_batch_map.py test_multi_col_map() function")
test_multi_col_map()
logger.info("Running test_var_batch_map.py test_exceptions_2() function")
test_exceptions_2()

Loading…
Cancel
Save