|
|
|
@ -42,6 +42,7 @@
|
|
|
|
|
#include "dataset/util/status.h"
|
|
|
|
|
#include "dataset/util/task_manager.h"
|
|
|
|
|
#include "dataset/util/wait_post.h"
|
|
|
|
|
#include "utils/system/crc32c.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
|
|
|
|
|
builder_data_schema_ = std::make_unique<DataSchema>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ValidateFirstRowCrc(const std::string &filename) {
|
|
|
|
|
std::ifstream reader;
|
|
|
|
|
reader.open(filename);
|
|
|
|
|
if (!reader) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// read data
|
|
|
|
|
int64_t record_length = 0;
|
|
|
|
|
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
|
|
|
|
|
|
|
|
|
|
// read crc from file
|
|
|
|
|
uint32_t masked_crc = 0;
|
|
|
|
|
(void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));
|
|
|
|
|
|
|
|
|
|
// generate crc from data
|
|
|
|
|
uint32_t generated_crc =
|
|
|
|
|
system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));
|
|
|
|
|
|
|
|
|
|
return masked_crc == generated_crc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status TFReaderOp::Builder::ValidateInputs() const {
|
|
|
|
|
std::string err_msg;
|
|
|
|
|
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : "";
|
|
|
|
|
if (!builder_equal_rows_per_shard_) {
|
|
|
|
|
err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)
|
|
|
|
|
? "No enough tf_file files provided\n"
|
|
|
|
|
: "";
|
|
|
|
|
|
|
|
|
|
if (builder_num_workers_ <= 0) {
|
|
|
|
|
err_msg += "Number of parallel workers is smaller or equal to 0\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!builder_equal_rows_per_shard_ &&
|
|
|
|
|
builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
|
|
|
|
|
err_msg += "Not enough tfrecord files provided\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
|
|
|
|
|
err_msg += "Wrong sharding configs\n";
|
|
|
|
|
}
|
|
|
|
|
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
|
|
|
|
|
auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
|
|
|
|
|
[](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
|
|
|
|
|
invalid_files.resize(std::distance(invalid_files.begin(), it));
|
|
|
|
|
|
|
|
|
|
if (!invalid_files.empty()) {
|
|
|
|
|
err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n";
|
|
|
|
|
|
|
|
|
|
std::string accumulated_filenames = std::accumulate(
|
|
|
|
|
invalid_files.begin(), invalid_files.end(), std::string(""),
|
|
|
|
|
[](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
|
|
|
|
|
err_msg += accumulated_filenames;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
|
|
|
|
|
RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
|
|
|
|
|
rows_read++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ignore crc footer
|
|
|
|
|
(void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
|
|
|
|
|
rows_total++;
|
|
|
|
|