!7931 Add validation to tf_reader_op

Merge pull request !7931 from ZiruiWu/batch_cpp_api_pyfunc
pull/7931/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 427599b2df

@ -22,15 +22,54 @@
#include <utility>
#include <vector>
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/status.h"
#include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
namespace api {
bool ValidateFirstRowCrc(const std::string &filename) {
std::ifstream reader;
reader.open(filename);
if (!reader) {
return false;
}
// read data
int64_t record_length = 0;
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
// read crc from file
uint32_t masked_crc = 0;
(void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));
// generate crc from data
uint32_t generated_crc =
system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));
return masked_crc == generated_crc;
}
// Validator for TFRecordNode
Status TFRecordNode::ValidateParams() { return Status::OK(); }
Status TFRecordNode::ValidateParams() {
std::vector<std::string> invalid_files(dataset_files_.size());
auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
[](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
invalid_files.resize(std::distance(invalid_files.begin(), it));
std::string err_msg;
if (!invalid_files.empty()) {
err_msg += "Invalid file, the following files either cannot be opened, or are not valid tfrecord files:\n";
std::string accumulated_filenames = std::accumulate(
invalid_files.begin(), invalid_files.end(), std::string(""),
[](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
err_msg += accumulated_filenames;
}
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
// Function to build TFRecordNode
std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {

@ -398,11 +398,9 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) {
// Create a TFRecord Dataset
// Each file has two columns("image", "label") and 3 rows
std::vector<std::string> files = {
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"
};
std::vector<std::string> files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data",
datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"};
std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true);
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false);
@ -505,3 +503,12 @@ TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) {
// this will fail due to the incorrect schema used
EXPECT_FALSE(itr->GetNextRow(&mp));
}
TEST_F(MindDataTestPipeline, TestIncorrectTFrecordFile) {
std::string path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json";
std::shared_ptr<api::Dataset> ds = api::TFRecord({path});
EXPECT_NE(ds, nullptr);
// the tf record file is incorrect, hence validate param will fail
auto itr = ds->CreateIterator();
EXPECT_EQ(itr, nullptr);
}

Loading…
Cancel
Save