diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 0764d7e0ad..a72be1f703 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -42,6 +42,7 @@ #include "dataset/util/status.h" #include "dataset/util/task_manager.h" #include "dataset/util/wait_post.h" +#include "utils/system/crc32c.h" namespace mindspore { namespace dataset { @@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder() builder_data_schema_ = std::make_unique(); } +bool ValidateFirstRowCrc(const std::string &filename) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + return false; + } + + // read data + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // read crc from file + uint32_t masked_crc = 0; + (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); + + // generate crc from data + uint32_t generated_crc = + system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); + + return masked_crc == generated_crc; +} + Status TFReaderOp::Builder::ValidateInputs() const { std::string err_msg; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : ""; - if (!builder_equal_rows_per_shard_) { - err_msg += builder_dataset_files_list_.size() < static_cast(builder_num_devices_) - ? "No enough tf_file files provided\n" - : ""; + + if (builder_num_workers_ <= 0) { + err_msg += "Number of parallel workers is smaller or equal to 0\n"; + } + + if (!builder_equal_rows_per_shard_ && + builder_dataset_files_list_.size() < static_cast(builder_num_devices_)) { + err_msg += "Not enough tfrecord files provided\n"; + } + + if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { + err_msg += "Wrong sharding configs\n"; } - err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + + std::vector invalid_files(builder_dataset_files_list_.size()); + auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), + [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + invalid_files.resize(std::distance(invalid_files.begin(), it)); + + if (!invalid_files.empty()) { + err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; + + std::string accumulated_filenames = std::accumulate( + invalid_files.begin(), invalid_files.end(), std::string(""), + [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); + err_msg += accumulated_filenames; + } + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); } @@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); rows_read++; } + // ignore crc footer (void)reader.ignore(static_cast(sizeof(int32_t))); rows_total++; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 57ce07b927..5b3c0f1503 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -926,13 +926,22 @@ class SourceDataset(Dataset): List, files. """ - def flat(lists): - return list(np.array(lists).flatten()) - if not isinstance(patterns, list): patterns = [patterns] - file_list = flat([glob.glob(file, recursive=True) for file in patterns]) + file_list = [] + unmatched_patterns = [] + for pattern in patterns: + matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)] + + if matches: + file_list.extend(matches) + else: + unmatched_patterns.append(pattern) + + if unmatched_patterns: + raise ValueError("The following patterns did not match any files: ", unmatched_patterns) + if file_list: # not empty return file_list raise ValueError("The list of path names matching the patterns is empty.") diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc index 5fb1f4e909..9b312296d8 100644 --- a/tests/ut/cpp/dataset/tfReader_op_test.cc +++ b/tests/ut/cpp/dataset/tfReader_op_test.cc @@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) { TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true); ASSERT_EQ(total_rows, 60); } + +TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) { + // Start with an empty execution tree + auto my_tree = std::make_shared(); + + std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt"; + std::string nonexistent_file = "this/file/doesnt/exist"; + + std::shared_ptr my_tfreader_op; + TFReaderOp::Builder builder; + builder.SetDatasetFilesList({invalid_file, valid_file, schema_file}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16); + + std::unique_ptr schema = std::make_unique(); + schema->LoadSchemaFile(schema_file, {}); + builder.SetDataSchema(std::move(schema)); + + Status rc = builder.Build(&my_tfreader_op); + ASSERT_TRUE(!rc.IsOk()); + + builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16); + + schema = std::make_unique(); + schema->LoadSchemaFile(schema_file, {}); + builder.SetDataSchema(std::move(schema)); + + rc = builder.Build(&my_tfreader_op); + ASSERT_TRUE(!rc.IsOk()); +} diff --git a/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data index c5b5440cff..f3bb23af51 100644 Binary files a/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data and b/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data differ diff --git a/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data index c5b5440cff..f3bb23af51 100644 Binary files a/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data and b/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data differ diff --git a/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data index c5b5440cff..f3bb23af51 100644 Binary files a/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data and b/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data differ diff --git a/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt b/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt new file mode 100644 index 0000000000..3307b71672 --- /dev/null +++ b/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt @@ -0,0 +1 @@ +this is just a text file, not a valid tfrecord file. diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py index 6de14df34e..3add50e1cb 100644 --- a/tests/ut/python/dataset/test_tfreader_op.py +++ b/tests/ut/python/dataset/test_tfreader_op.py @@ -32,7 +32,7 @@ def test_case_tf_shape(): ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds1.batch(2) for data in ds1.create_dict_iterator(): - print(data) + logger.info(data) output_shape = ds1.output_shapes() assert (len(output_shape[-1]) == 1) @@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list(): a = row["col_sint32"] assert "col_sint32" in str(info.value) +def test_case_invalid_files(): + valid_file = "../data/dataset/testTFTestAllTypes/test.data" + invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" + files = [invalid_file, valid_file, SCHEMA_FILE] + + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + + with pytest.raises(RuntimeError) as info: + row = data.create_dict_iterator().get_next() + assert "cannot be opened" in str(info.value) + assert "not valid tfrecord files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file in str(info.value) + assert SCHEMA_FILE in str(info.value) + + nonexistent_file = "this/file/does/not/exist" + files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file] + + with pytest.raises(ValueError) as info: + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + assert "did not match any files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file not in str(info.value) + assert SCHEMA_FILE not in str(info.value) + assert nonexistent_file in str(info.value) + if __name__ == '__main__': test_case_tf_shape() test_case_tf_file() @@ -212,3 +238,4 @@ if __name__ == '__main__': test_tf_record_schema() test_tf_record_shuffle() test_tf_shard_equal_rows() + test_case_invalid_files()