From ea5da25d10444d4f86098b6216f97d56115364de Mon Sep 17 00:00:00 2001 From: jonyguo Date: Mon, 13 Apr 2020 22:07:19 +0800 Subject: [PATCH] fix: use exactly read option --- .../mindrecord/io/shard_index_generator.cc | 4 +++ mindspore/ccsrc/mindrecord/io/shard_reader.cc | 36 ++++++++----------- mindspore/ccsrc/mindrecord/io/shard_writer.cc | 25 +++++++++---- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index 254ddfbb16..5a5cd7cbf3 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -512,6 +512,10 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std std::fstream in; in.open(common::SafeCStr(shard_address), std::ios::in | std::ios::binary); + if (!in.good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); for (int raw_page_id : raw_page_ids) { auto sql = GenerateRawSQL(fields_); diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 2413da3737..085f148a88 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -125,13 +125,10 @@ MSRStatus ShardReader::Open() { for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::trunc | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; } MS_LOG(INFO) << "Open shard file successfully."; file_streams_.push_back(fs); @@ -146,13 +143,10 @@ MSRStatus ShardReader::Open(int n_consumer) { for (const auto &file : file_paths_) { for (int j = 0; j < n_consumer; ++j) { std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::trunc | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; } file_streams_random_[j].push_back(fs); } @@ -311,12 +305,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, std::string file_name = file_paths_[shard_id]; std::shared_ptr fs = std::make_shared(); if (!all_in_index_) { - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::out | std::ios::trunc | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - } + fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; } } sqlite3_free(errmsg); @@ -520,8 +512,8 @@ std::pair> ShardReader::GetLabelsFromBinaryFile( std::string file_name = file_paths_[shard_id]; std::vector res; std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { + fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); + if (!fs->good()) { MS_LOG(ERROR) << "File could not opened"; return {FAILED, {}}; } diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index 54cf0e156b..3d4259ebbd 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -76,16 +76,27 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) // Open files for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::trunc | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; + if (!append) { + // if not append and mindrecord file exist, return FAILED + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (fs->good()) { + MS_LOG(ERROR) << "MindRecord file already existed."; + fs->close(); + return FAILED; + } + fs->close(); + + // open the mindrecord file to write + fs->open(common::SafeCStr(file), std::ios::out | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "MindRecord file could not opened."; return FAILED; } } else { - if (!append) { - MS_LOG(ERROR) << "MindRecord file already existed"; + // open the mindrecord file to append + fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "MindRecord file could not opened for append."; return FAILED; } }