From c42a1b4a2f75d801b2bcd7709a43bbb929fd30e4 Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Tue, 13 Oct 2020 10:26:50 -0400 Subject: [PATCH] fix batch core dump when col name doesn't exist fix ci --- .../dataset/engine/datasetops/batch_op.cc | 11 +++++--- tests/ut/python/dataset/test_var_batch_map.py | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index d7049ba018..a142443b29 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -489,6 +489,11 @@ Status BatchOp::ComputeColMap() { // from this point onward, per_batch_map is needed, therefore, child_map_ must be set child_map_ = child_[0]->column_name_id_map(); + // check all input columns exist + for (const auto &col : in_col_names_) { + CHECK_FAIL_RETURN_UNEXPECTED(child_map_.find(col) != child_map_.end(), "col:" + col + " doesn't exist."); + } + // following logic deals with per_batch_map bool col_name_flag = (out_col_names_.empty() || out_col_names_ == in_col_names_); // true if col name is unchanged @@ -498,13 +503,11 @@ Status BatchOp::ComputeColMap() { column_name_id_map_ = child_map_; return Status::OK(); } - // column names are changed from this point onward, this map is the child_map without input cols for per_batch_map auto child_map_no_in_col = child_map_; + for (const auto &col : in_col_names_) { - const auto itr = child_map_no_in_col.find(col); - CHECK_FAIL_RETURN_UNEXPECTED(itr != child_map_no_in_col.end(), "col:" + col + " doesn't exist."); - child_map_no_in_col.erase(itr); + child_map_no_in_col.erase(col); } // col names are changed diff --git a/tests/ut/python/dataset/test_var_batch_map.py b/tests/ut/python/dataset/test_var_batch_map.py index f4c60b08cd..e0b8e34aef 100644 --- a/tests/ut/python/dataset/test_var_batch_map.py +++ b/tests/ut/python/dataset/test_var_batch_map.py @@ -374,6 +374,28 @@ def test_multi_col_map(): assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"]) +def test_exceptions_2(): + def gen(num): + for i in range(num): + yield (np.array([i]),) + + def simple_copy(colList, batchInfo): + return ([np.copy(arr) for arr in colList],) + + def test_wrong_col_name(gen_num, batch_size): + data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"], + per_batch_map=simple_copy) + try: + for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + pass + return "success" + except RuntimeError as e: + return str(e) + + # test exception where column name is incorrect + assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2) + + if __name__ == '__main__': logger.info("Running test_var_batch_map.py test_batch_corner_cases() function") test_batch_corner_cases() @@ -398,3 +420,6 @@ if __name__ == '__main__': logger.info("Running test_var_batch_map.py test_multi_col_map() function") test_multi_col_map() + + logger.info("Running test_var_batch_map.py test_exceptions_2() function") + test_exceptions_2()