From 814044963be47ce4a47ba781a041d0c371419e9c Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Wed, 28 Oct 2020 11:32:09 -0400 Subject: [PATCH] add validation check in tfrecord_node update ut --- .../ir/datasetops/source/tf_record_node.cc | 45 +++++++++++++++++-- .../dataset/c_api_dataset_tfrecord_test.cc | 17 ++++--- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 1dac924d60..2e1ca484dd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -22,15 +22,54 @@ #include #include -#include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" - +#include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/util/status.h" +#include "utils/system/crc32c.h" + namespace mindspore { namespace dataset { namespace api { + +bool ValidateFirstRowCrc(const std::string &filename) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + return false; + } + + // read data + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // read crc from file + uint32_t masked_crc = 0; + (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); + + // generate crc from data + uint32_t generated_crc = + system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); + + return masked_crc == generated_crc; +} + // Validator for TFRecordNode -Status TFRecordNode::ValidateParams() { return Status::OK(); } +Status TFRecordNode::ValidateParams() { + std::vector invalid_files(dataset_files_.size()); + auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(), + [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + invalid_files.resize(std::distance(invalid_files.begin(), it)); + std::string err_msg; + if (!invalid_files.empty()) { + err_msg += "Invalid file, the following files either cannot be opened, or are not valid tfrecord files:\n"; + + std::string accumulated_filenames = std::accumulate( + invalid_files.begin(), invalid_files.end(), std::string(""), + [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); + err_msg += accumulated_filenames; + } + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} // Function to build TFRecordNode std::vector> TFRecordNode::Build() { diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc index c65590aada..bed08f6ada 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc @@ -398,11 +398,9 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) { // Create a TFRecord Dataset // Each file has two columns("image", "label") and 3 rows - std::vector files = { - datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", - datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", - datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data" - }; + std::vector files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", + datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"}; std::shared_ptr ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); EXPECT_NE(ds1, nullptr); std::shared_ptr ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); @@ -505,3 +503,12 @@ TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) { // this will fail due to the incorrect schema used EXPECT_FALSE(itr->GetNextRow(&mp)); } + +TEST_F(MindDataTestPipeline, TestIncorrectTFrecordFile) { + std::string path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json"; + std::shared_ptr ds = api::TFRecord({path}); + EXPECT_NE(ds, nullptr); + // the tf record file is incorrect, hence validate param will fail + auto itr = ds->CreateIterator(); + EXPECT_EQ(itr, nullptr); +}