diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 8ce9e3fdfc..13d68b01f7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -280,9 +280,6 @@ class ShardReader { /// \brief read one row by one task TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); - /// \brief get all the column names by schema - vector GetAllColumns(); - /// \brief get one row from buffer in block-reader mode std::shared_ptr, json>>> GetRowFromBuffer(int bufId, int rowId); @@ -308,7 +305,6 @@ class ShardReader { uint64_t page_size_; // page size int shard_count_; // number of shards std::shared_ptr shard_header_; // shard header - bool nlp_ = false; // NLP data std::vector database_paths_; // sqlite handle list std::vector file_paths_; // file paths diff --git a/mindspore/ccsrc/mindrecord/include/shard_segment.h b/mindspore/ccsrc/mindrecord/include/shard_segment.h index 9ffb7aee88..12497a5ace 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_segment.h +++ b/mindspore/ccsrc/mindrecord/include/shard_segment.h @@ -90,8 +90,6 @@ class ShardSegment : public ShardReader { std::string CleanUp(std::string fieldName); - std::tuple, json> GetImageLabel(std::vector images, json label); - std::pair> PackImages(int group_id, int shard_id, std::vector offset); std::vector candidate_category_fields_; diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index ed5af0e6af..1f0a6b8dce 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -433,7 +433,6 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { } ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns) { - std::lock_guard lck(shard_locker_); const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); if (SUCCESS != ret.first) { return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); @@ -455,7 +454,6 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, const std::vector &columns) { - std::lock_guard lck(shard_locker_); const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); if (SUCCESS != ret.first) { return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); @@ -532,13 +530,6 @@ std::vector> ShardReader::GetImageOffset(int page_id, int return res; } -void ShardReader::CheckNlp() { - nlp_ = false; - return; -} - -bool ShardReader::GetNlpFlag() { return nlp_; } - std::pair> ShardReader::GetBlobFields() { std::vector blob_fields; for (auto &p : GetShardHeader()->GetSchemas()) { @@ -547,7 +538,7 @@ std::pair> ShardReader::GetBlobFields() { blob_fields.assign(fields.begin(), fields.end()); break; } - return std::make_pair(nlp_ ? kNLP : kCV, blob_fields); + return std::make_pair(kCV, blob_fields); } void ShardReader::CheckIfColumnInIndex(const std::vector &columns) { @@ -828,18 +819,11 @@ MSRStatus ShardReader::Open(const std::vector &file_paths, bool loa if (n_consumer < kMinConsumerCount) { n_consumer = kMinConsumerCount; } - CheckNlp(); - - // dead code - if (nlp_) { - selected_columns_ = selected_columns; - } else { - vector blob_fields = GetBlobFields().second; - for (unsigned int i = 0; i < selected_columns.size(); ++i) { - if (!std::any_of(blob_fields.begin(), blob_fields.end(), - [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { - selected_columns_.push_back(selected_columns[i]); - } + vector blob_fields = GetBlobFields().second; + for (unsigned int i = 0; i < selected_columns.size(); ++i) { + if (!std::any_of(blob_fields.begin(), blob_fields.end(), + [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { + selected_columns_.push_back(selected_columns[i]); } } selected_columns_ = selected_columns; @@ -895,7 +879,6 @@ MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool l if (Open(n_consumer) == FAILED) { return FAILED; } - CheckNlp(); // Initialize argument shard_count_ = static_cast(file_paths_.size()); n_consumer_ = n_consumer; @@ -918,10 +901,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { interrupt_ = true; return FAILED; } - MS_LOG(INFO) << "Launching read threads."; - if (isSimpleReader) return SUCCESS; - // Start provider consumer threads thread_set_ = std::vector(n_consumer_); if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) { @@ -940,29 +920,9 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { return SUCCESS; } -vector ShardReader::GetAllColumns() { - vector columns; - if (nlp_) { - for (auto &c : selected_columns_) { - for (auto &p : GetShardHeader()->GetSchemas()) { - auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. - for (auto it = schema.begin(); it != schema.end(); ++it) { - if (it.key() == c) { - columns.push_back(c); - } - } - } - } - } else { - columns = selected_columns_; - } - return columns; -} - MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, const std::vector> &operators) { - vector columns = GetAllColumns(); - CheckIfColumnInIndex(columns); + CheckIfColumnInIndex(selected_columns_); for (const auto &rg : row_group_summary) { auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); @@ -974,9 +934,7 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, const std::shared_ptr &op) { - vector columns = GetAllColumns(); - CheckIfColumnInIndex(columns); - + CheckIfColumnInIndex(selected_columns_); auto category_op = std::dynamic_pointer_cast(op); auto categories = category_op->GetCategories(); int64_t num_elements = category_op->GetNumElements(); @@ -1011,7 +969,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector(rg); auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns); + auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); if (SUCCESS != std::get<0>(details)) { return FAILED; } @@ -1037,10 +995,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, const std::vector> &operators) { - vector columns = GetAllColumns(); - CheckIfColumnInIndex(columns); + CheckIfColumnInIndex(selected_columns_); - auto ret = ReadAllRowGroup(columns); + auto ret = ReadAllRowGroup(selected_columns_); if (std::get<0>(ret) != SUCCESS) { return FAILED; } @@ -1202,28 +1159,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ // Deliver batch data to output map std::vector, json>> batch; - if (nlp_) { - // dead code - json blob_fields = json::from_msgpack(images_with_exact_columns); - - json merge; - if (selected_columns_.size() > 0) { - for (auto &col : selected_columns_) { - if (blob_fields.find(col) != blob_fields.end()) { - merge[col] = blob_fields[col]; - } - } - } else { - merge = blob_fields; - } - auto label_json = std::get<2>(task); - if (label_json != nullptr) { - merge.update(label_json); - } - batch.emplace_back(std::vector{}, std::move(merge)); - } else { - batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task))); - } + batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task))); return std::make_pair(SUCCESS, std::move(batch)); } diff --git a/mindspore/ccsrc/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/mindrecord/io/shard_segment.cc index d6536996ba..86c79ca05a 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_segment.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_segment.cc @@ -296,8 +296,7 @@ std::pair, json>>> ShardS if (SUCCESS != ret1.first) { return {FAILED, std::vector, json>>{}}; } - auto imageLabel = GetImageLabel(ret1.second, labels[i]); - page.emplace_back(std::move(std::get<0>(imageLabel)), std::move(std::get<1>(imageLabel))); + page.emplace_back(std::move(ret1.second), std::move(labels[i])); } } } @@ -371,35 +370,7 @@ std::pair> ShardSegment::GetBlobFields() { blob_fields.assign(fields.begin(), fields.end()); break; } - return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields); -} - -std::tuple, json> ShardSegment::GetImageLabel(std::vector images, json label) { - if (GetNlpFlag()) { - vector columns; - for (auto &p : GetShardHeader()->GetSchemas()) { - auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. - auto schema_items = schema.items(); - using it_type = decltype(schema_items.begin()); - std::transform(schema_items.begin(), schema_items.end(), std::back_inserter(columns), - [](it_type item) { return item.key(); }); - } - - json blob_fields = json::from_msgpack(images); - json merge; - if (columns.size() > 0) { - for (auto &col : columns) { - if (blob_fields.find(col) != blob_fields.end()) { - merge[col] = blob_fields[col]; - } - } - } else { - merge = blob_fields; - } - merge.update(label); - return std::make_tuple(std::vector{}, merge); - } - return std::make_tuple(images, label); + return std::make_pair(kCV, blob_fields); } std::string ShardSegment::CleanUp(std::string field_name) {