diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ec9cf58e9c..d8969eae7f 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -501,15 +501,15 @@ def check_batch(method): for k, v in param_dict.get('pad_info').items(): check_pad_info(k, v) + if (per_batch_map is None) != (input_columns is None): + # These two parameters appear together. + raise ValueError("per_batch_map and input_columns need to be passed in together.") + if input_columns is not None: check_columns(input_columns, "input_columns") if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): raise ValueError("the signature of per_batch_map should match with input columns") - if (per_batch_map is None) != (input_columns is None): - # These two parameters appear together. - raise ValueError("per_batch_map and input_columns need to be passed in together.") - if output_columns is not None: raise ValueError("output_columns is currently not implemented.") diff --git a/tests/ut/python/dataset/test_batch.py b/tests/ut/python/dataset/test_batch.py index 9e5e0139e4..4130564521 100644 --- a/tests/ut/python/dataset/test_batch.py +++ b/tests/ut/python/dataset/test_batch.py @@ -466,6 +466,16 @@ def test_batch_exception_13(): logger.info("Got an exception in DE: {}".format(str(e))) assert "column_order is currently not implemented." in str(e) +def test_batch_exception_14(): + batch_size = 2 + input_columns = ["num"] + data1 = ds.TFRecordDataset(DATA_DIR) + try: + _ = data1.batch(batch_size=batch_size, input_columns=input_columns) + except ValueError as e: + assert "per_batch_map and input_columns need to be passed in together." in str(e) + + if __name__ == '__main__': test_batch_01() test_batch_02() @@ -491,4 +501,5 @@ if __name__ == '__main__': test_batch_exception_11() test_batch_exception_12() test_batch_exception_13() + test_batch_exception_14() logger.info('\n')