diff --git a/example/convert_to_mindrecord/README.md b/example/convert_to_mindrecord/README.md
new file mode 100644
index 0000000000..8d3b25e311
--- /dev/null
+++ b/example/convert_to_mindrecord/README.md
@@ -0,0 +1,46 @@
+# MindRecord generating guidelines
+
+- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
+    - [Create work space](#create-work-space)
+    - [Implement data generator](#implement-data-generator)
+    - [Run data generator](#run-data-generator)
+
+## Create work space
+
+Assume the dataset name is 'xyz'.
+* Create a work space from the template:
+    ```shell
+    cd ${your_mindspore_home}/example/convert_to_mindrecord
+    cp -r template xyz
+    ```
+
+## Implement data generator
+
+Edit the dictionary data generator.
+* Edit the file:
+    ```shell
+    cd ${your_mindspore_home}/example/convert_to_mindrecord
+    vi xyz/mr_api.py
+    ```
+
+Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented:
+- 'mindrecord_task_number()' returns the number of tasks. Return 1 if the data rows are generated serially. Return N if the generator can be split into N parallel tasks.
+- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
+
+Hints for splitting a parallel run into tasks:
+- For ImageNet, each directory can be a task.
+- For a TFRecord dataset with multiple files, each file can be a task.
+- For a TFRecord dataset with a single file, the file can still be split into N tasks: with task_id=K, a data row is picked only if (count % N == K), as shown in the sketch below.
+
+## Run data generator
+* Run the python script:
+    ```shell
+    cd ${your_mindspore_home}/example/convert_to_mindrecord
+    python writer.py --mindrecord_script imagenet [...]
+    ```
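To make the modulo rule above concrete, here is a minimal sketch of a generator for a hypothetical single-file dataset split into N parallel tasks; the file name and the row parser are placeholders, not part of the generated work space:

```python
N = 4  # value that mindrecord_task_number() would return

def mindrecord_dict_data(task_id):
    """Yield only the rows assigned to this task (count % N == task_id)."""
    count = 0
    with open('single_file.txt') as data_file:  # hypothetical input file
        for line in data_file:
            if count % N == task_id:
                yield {"label": int(line)}  # hypothetical row parser
            count += 1
```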
diff --git a/example/convert_to_mindrecord/imagenet/__init__.py b/example/convert_to_mindrecord/imagenet/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/convert_to_mindrecord/imagenet/mr_api.py b/example/convert_to_mindrecord/imagenet/mr_api.py
new file mode 100644
index 0000000000..e569b489b5
--- /dev/null
+++ b/example/convert_to_mindrecord/imagenet/mr_api.py
@@ -0,0 +1,122 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+User-defined API for the MindRecord writer.
+Two APIs must be implemented:
+    1. mindrecord_task_number()
+       # Return the number of parallel tasks; return 1 if there is no parallelism.
+    2. mindrecord_dict_data(task_id)
+       # Yield the data for one task.
+       # task_id ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
+"""
+import argparse
+import os
+import pickle
+
+######## mindrecord_schema begin ##########
+mindrecord_schema = {"label": {"type": "int64"},
+                     "data": {"type": "bytes"},
+                     "file_name": {"type": "string"}}
+######## mindrecord_schema end ##########
+
+######## Frozen code begin ##########
+with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
+    ARG_LIST = pickle.load(mindrecord_argument_file_handle)
+######## Frozen code end ##########
+
+parser = argparse.ArgumentParser(description='Mind record imagenet example')
+parser.add_argument('--label_file', type=str, default="", help='label file')
+parser.add_argument('--image_dir', type=str, default="", help='images directory')
+
+######## Frozen code begin ##########
+args = parser.parse_args(ARG_LIST)
+print(args)
+######## Frozen code end ##########
+
+
+def _user_defined_private_func():
+    """
+    Internal function that builds the tasks list.
+
+    Return:
+        tasks list
+    """
+    if not os.path.exists(args.label_file):
+        raise IOError("label file {} does not exist".format(args.label_file))
+
+    label_dict = {}
+    with open(args.label_file) as file_handle:
+        for line in file_handle:
+            labels = line.split(" ")
+            label_dict[labels[1]] = labels[0]
+    # collect all the directories, e.g. n02087046, n02094114, n02109525
+    dir_paths = {}
+    for item in label_dict:
+        real_path = os.path.join(args.image_dir, label_dict[item])
+        if not os.path.isdir(real_path):
+            print("directory {} does not exist".format(real_path))
+            continue
+        dir_paths[item] = real_path
+
+    if not dir_paths:
+        print("no valid image directory found in {}".format(args.image_dir))
+        return {}, {}
+
+    dir_list = []
+    for label in dir_paths:
+        dir_list.append(label)
+    return dir_list, dir_paths
+
+
+dir_list_global, dir_paths_global = _user_defined_private_func()
+
+def mindrecord_task_number():
+    """
+    Get the task size.
+
+    Return:
+        number of tasks
+    """
+    return len(dir_list_global)
+
+
+def mindrecord_dict_data(task_id):
+    """
+    Get a data dict generator.
+
+    Yields:
+        data (dict): a data row as a dict.
+    """
+
+    # get the file name, label and image binary as a dict
+    label = dir_list_global[task_id]
+    for item in os.listdir(dir_paths_global[label]):
+        file_name = os.path.join(dir_paths_global[label], item)
+        if not item.endswith(("JPEG", "jpg", "jpeg")):
+            print("file {} does not have a JPEG/jpg/jpeg suffix, skip it.".format(file_name))
+            continue
+        data = {}
+        data["file_name"] = str(file_name)
+        data["label"] = int(label)
+
+        # read the image binary data
+        with open(file_name, "rb") as image_file:
+            data["data"] = image_file.read()
+        yield data
diff --git a/example/convert_to_mindrecord/run_imagenet.sh b/example/convert_to_mindrecord/run_imagenet.sh
new file mode 100644
index 0000000000..11f5dcff75
--- /dev/null
+++ b/example/convert_to_mindrecord/run_imagenet.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+rm -f /tmp/imagenet/mr/*
+
+python writer.py --mindrecord_script imagenet \
+--mindrecord_file "/tmp/imagenet/mr/m" \
+--mindrecord_partitions 16 \
+--label_file "/tmp/imagenet/label.txt" \
+--image_dir "/tmp/imagenet/jpeg"
diff --git a/example/convert_to_mindrecord/run_template.sh b/example/convert_to_mindrecord/run_template.sh
new file mode 100644
index 0000000000..a4c5142c00
--- /dev/null
+++ b/example/convert_to_mindrecord/run_template.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+rm -f /tmp/template/*
+
+python writer.py --mindrecord_script template \
+--mindrecord_file "/tmp/template/m" \
+--mindrecord_partitions 4
diff --git a/example/convert_to_mindrecord/template/__init__.py b/example/convert_to_mindrecord/template/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
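After run_imagenet.sh finishes, the output can be spot-checked by reading one shard back. A minimal sketch, assuming the writer names its partitions with numeric suffixes such as m0 (an assumption about the naming scheme):

```python
from mindspore.mindrecord import FileReader

# one of the shards produced by run_imagenet.sh
reader = FileReader("/tmp/imagenet/mr/m0")
count = 0
for item in reader.get_next():  # each item is a dict matching mindrecord_schema
    count += 1
print("read {} records".format(count))
reader.close()
```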
diff --git a/example/convert_to_mindrecord/template/mr_api.py b/example/convert_to_mindrecord/template/mr_api.py
new file mode 100644
index 0000000000..3f7d7dddf0
--- /dev/null
+++ b/example/convert_to_mindrecord/template/mr_api.py
@@ -0,0 +1,73 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+User-defined API for the MindRecord writer.
+Two APIs must be implemented:
+    1. mindrecord_task_number()
+       # Return the number of parallel tasks; return 1 if there is no parallelism.
+    2. mindrecord_dict_data(task_id)
+       # Yield the data for one task.
+       # task_id ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
+"""
+import argparse
+import pickle
+
+# ## Parse arguments
+
+with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:  # Do NOT change this line
+    ARG_LIST = pickle.load(mindrecord_argument_file_handle)  # Do NOT change this line
+parser = argparse.ArgumentParser(description='Mind record api template')  # Do NOT change this line
+
+# ## Your arguments below
+# parser.add_argument(...)
+
+args = parser.parse_args(ARG_LIST)  # Do NOT change this line
+print(args)  # Do NOT change this line
+
+
+# ## Default mindrecord vars. Keep them commented out unless a default value has to be changed.
+# mindrecord_index_fields = ['label']
+# mindrecord_header_size = 1 << 24
+# mindrecord_page_size = 1 << 25
+
+
+# define global vars here if necessary
+
+
+# ####### Your code below ##########
+mindrecord_schema = {"label": {"type": "int32"}}
+
+def mindrecord_task_number():
+    """
+    Get the task size.
+
+    Return:
+        number of tasks
+    """
+    return 1
+
+
+def mindrecord_dict_data(task_id):
+    """
+    Get a data dict generator.
+
+    Yields:
+        data (dict): a data row as a dict.
+    """
+    print("task is {}".format(task_id))
+    for i in range(256):
+        data = {}
+        data['label'] = i
+        yield data
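To show how a copied template is typically filled in, here is a hypothetical completion for a two-column CSV dataset. Only the slots the template marks as user code are shown; `parser`, `args`, and `ARG_LIST` come from the template's frozen lines, and the --data_csv argument and the "label,text" row layout are assumptions for illustration:

```python
import csv

# ## Your arguments below
parser.add_argument('--data_csv', type=str, default="", help='csv file with "label,text" rows')

# ####### Your code below ##########
mindrecord_schema = {"label": {"type": "int32"},
                     "text": {"type": "string"}}

def mindrecord_task_number():
    return 1  # the whole csv is read as one serial task

def mindrecord_dict_data(task_id):
    with open(args.data_csv) as csv_file:
        for row in csv.reader(csv_file):
            yield {"label": int(row[0]), "text": row[1]}
```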
print("Write mindrecord ...") + + mindrecord_dict_data = mr_api.mindrecord_dict_data + + # get number of files + writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) + + start_time = time.time() + + # set the header size + try: + header_size = mr_api.mindrecord_header_size + writer.set_header_size(header_size) + except AttributeError: + print("Default header size: {}".format(1 << 24)) + + # set the page size + try: + page_size = mr_api.mindrecord_page_size + writer.set_page_size(page_size) + except AttributeError: + print("Default page size: {}".format(1 << 25)) + + # get schema + try: + mindrecord_schema = mr_api.mindrecord_schema + except AttributeError: + raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") + + # create the schema + writer.add_schema(mindrecord_schema, "mindrecord_schema") + + # add the index + try: + index_fields = mr_api.mindrecord_index_fields + writer.add_index(index_fields) + except AttributeError: + print("Default index fields: all simple fields are indexes.") + + writer.open_and_set_header() + + task_list = list(range(num_tasks)) + + # set number of workers + num_workers = args.mindrecord_workers + + if num_tasks < 1: + num_tasks = 1 + + if num_workers > num_tasks: + num_workers = num_tasks + + if num_tasks > 1: + with Pool(num_workers) as p: + p.map(_exec_task, task_list) + else: + _exec_task(0, False) + + ret = writer.commit() + + os.remove("{}".format("mr_argument.pickle")) + + end_time = time.time() + print("--------------------------------------------") + print("END. Total time: {}".format(end_time - start_time)) + print("--------------------------------------------") diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 338a17ac2d..8718e9b871 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -75,12 +75,9 @@ void BindShardWriter(py::module *m) { .def("set_header_size", &ShardWriter::set_header_size) .def("set_page_size", &ShardWriter::set_page_size) .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", - (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool)) & - ShardWriter::WriteRawData) - .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map> &, - std::map> &, bool)) & - ShardWriter::WriteRawData) + .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, + vector> &, bool, bool)) & + ShardWriter::WriteRawData) .def("commit", &ShardWriter::Commit); } diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index ca4d3bd66f..70cfcdb6b7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -121,6 +121,10 @@ class ShardHeader { std::vector SerializeHeader(); + MSRStatus PagesToFile(const std::string dump_file_name); + + MSRStatus FileToPages(const std::string dump_file_name); + private: MSRStatus InitializeHeader(const std::vector &headers); diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 6a22f07700..78a434fc97 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -18,6 +18,7 @@ #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include +#include #include #include #include @@ -87,7 +88,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of 
diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc
index 338a17ac2d..8718e9b871 100644
--- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc
+++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc
@@ -75,12 +75,9 @@ void BindShardWriter(py::module *m) {
     .def("set_header_size", &ShardWriter::set_header_size)
     .def("set_page_size", &ShardWriter::set_page_size)
     .def("set_shard_header", &ShardWriter::SetShardHeader)
-    .def("write_raw_data",
-         (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, vector<std::vector<uint8_t>> &, bool)) &
-           ShardWriter::WriteRawData)
-    .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
-                                                          std::map<uint64_t, std::vector<py::handle>> &, bool)) &
-                                 ShardWriter::WriteRawData)
+    .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
+                                                      vector<std::vector<uint8_t>> &, bool, bool)) &
+                             ShardWriter::WriteRawData)
     .def("commit", &ShardWriter::Commit);
 }
diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h
index ca4d3bd66f..70cfcdb6b7 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_header.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_header.h
@@ -121,6 +121,10 @@ class ShardHeader {
 
   std::vector<std::string> SerializeHeader();
 
+  MSRStatus PagesToFile(const std::string dump_file_name);
+
+  MSRStatus FileToPages(const std::string dump_file_name);
+
  private:
   MSRStatus InitializeHeader(const std::vector<json> &headers);
diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h
index 6a22f07700..78a434fc97 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_writer.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h
@@ -18,6 +18,7 @@
 #define MINDRECORD_INCLUDE_SHARD_WRITER_H_
 
 #include <fcntl.h>
+#include <sys/file.h>
 #include <unistd.h>
 #include <algorithm>
 #include <array>
@@ -87,7 +88,7 @@ class ShardWriter {
   /// \param[in] sign validate data or not
   /// \return MSRStatus the status of MSRStatus to judge if write successfully
   MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
-                         bool sign = true);
+                         bool sign = true, bool parallel_writer = false);
 
   /// \brief write raw data by group size for call from python
   /// \param[in] raw_data the vector of raw json data, python-handle format
@@ -95,7 +96,7 @@ class ShardWriter {
   /// \param[in] sign validate data or not
   /// \return MSRStatus the status of MSRStatus to judge if write successfully
   MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
-                         bool sign = true);
+                         bool sign = true, bool parallel_writer = false);
 
   /// \brief write raw data by group size for call from python
   /// \param[in] raw_data the vector of raw json data, python-handle format
@@ -103,7 +104,8 @@ class ShardWriter {
   /// \param[in] sign validate data or not
   /// \return MSRStatus the status of MSRStatus to judge if write successfully
   MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
-                         std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true);
+                         std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
+                         bool parallel_writer = false);
 
  private:
   /// \brief write shard header data to disk
@@ -201,7 +203,34 @@ class ShardWriter {
   MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
                                   std::map<int, std::string> &err_raw_data);
 
+  /// \brief Lock writer and save pages info
+  int LockWriter(bool parallel_writer = false);
+
+  /// \brief Unlock writer and save pages info
+  MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
+
+  /// \brief Check raw data before writing
+  MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
+                                 bool sign, int *schema_count, int *row_count);
+
+  /// \brief Get full path from file name
+  MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
+
+  /// \brief Open files
+  MSRStatus OpenDataFiles(bool append);
+
+  /// \brief Remove lock file
+  MSRStatus RemoveLockFile();
+
+  /// \brief Initialize lock file
+  MSRStatus InitLockFile();
+
  private:
+  const std::string kLockFileSuffix = "_Locker";
+  const std::string kPageFileSuffix = "_Pages";
+  std::string lock_file_;   // lock file for parallel run
+  std::string pages_file_;  // temporary file of pages info for parallel run
+
   int shard_count_;        // number of files
   uint64_t header_size_;   // header size
   uint64_t page_size_;     // page size
@@ -211,7 +240,7 @@ class ShardWriter {
   std::vector<uint64_t> raw_data_size_;   // Raw data size
   std::vector<uint64_t> blob_data_size_;  // Blob data size
 
-  std::vector<std::string> file_paths_;   // file paths
+  std::vector<std::string> file_paths_;                      // file paths
   std::vector<std::shared_ptr<std::fstream>> file_streams_;  // file handles
   std::shared_ptr<ShardHeader> shard_header_;                // shard headers
diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
index 5a5cd7cbf3..dc2743cdc7 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
@@ -520,13 +520,16 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
   for (int raw_page_id : raw_page_ids) {
     auto sql = GenerateRawSQL(fields_);
     if (sql.first != SUCCESS) {
+      MS_LOG(ERROR) << "Generate raw SQL failed";
       return FAILED;
     }
     auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
     if (data.first != SUCCESS) {
+      MS_LOG(ERROR) << "Generate raw data failed";
       return FAILED;
     }
     if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
+      MS_LOG(ERROR) << "Execute SQL failed";
       return FAILED;
     }
     MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
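The shard_writer.cc changes below serialize concurrent writers with an exclusive flock on a "_Locker" file and round-trip page metadata through a "_Pages" file. A rough Python illustration of the same protocol (an aid for reading the C++, not the actual implementation; suffixes are taken from kLockFileSuffix and kPageFileSuffix):

```python
import fcntl
import json
import os

def locked_write(prefix, write_fn):
    """Illustrative flock-based critical section, mirroring LockWriter/UnlockWriter."""
    lock_path = prefix + "_Locker"
    pages_path = prefix + "_Pages"
    fd = os.open(lock_path, os.O_WRONLY | os.O_CREAT, 0o666)
    try:
        fcntl.flock(fd, fcntl.LOCK_EX)  # block until this process owns the writer
        pages = []
        if os.path.exists(pages_path):  # FileToPages: reload page info left by other processes
            with open(pages_path) as f:
                pages = [json.loads(line) for line in f]
        write_fn(pages)  # append rows, mutating the page info
        with open(pages_path, "w") as f:  # PagesToFile: persist page info for the next writer
            for page in pages:
                f.write(json.dumps(page) + "\n")
    finally:
        fcntl.flock(fd, fcntl.LOCK_UN)
        os.close(fd)
```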
diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc
index 864e6697d0..ac95e622c9 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc
@@ -40,17 +40,7 @@ ShardWriter::~ShardWriter() {
   }
 }
 
-MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
-  shard_count_ = paths.size();
-  if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
-    MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
-    return FAILED;
-  }
-  if (schema_count_ > kMaxSchemaCount) {
-    MS_LOG(ERROR) << "The schema Count greater than max value.";
-    return FAILED;
-  }
-
+MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
   // Get full path from file name
   for (const auto &path : paths) {
     if (!CheckIsValidUtf8(path)) {
@@ -60,7 +50,7 @@ MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append)
     char resolved_path[PATH_MAX] = {0};
     char buf[PATH_MAX] = {0};
     if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
-      MS_LOG(ERROR) << "Securec func failed";
+      MS_LOG(ERROR) << "Secure func failed";
       return FAILED;
     }
 #if defined(_WIN32) || defined(_WIN64)
@@ -82,7 +72,10 @@ MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append)
 #endif
     file_paths_.emplace_back(string(resolved_path));
   }
+  return SUCCESS;
+}
 
+MSRStatus ShardWriter::OpenDataFiles(bool append) {
   // Open files
   for (const auto &file : file_paths_) {
     std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
@@ -116,6 +109,67 @@ MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append)
   return SUCCESS;
 }
 
+MSRStatus ShardWriter::RemoveLockFile() {
+  // Remove temporary pages file
+  int ret = std::remove(pages_file_.c_str());
+  if (ret == 0) {
+    MS_LOG(DEBUG) << "Remove page file.";
+  }
+
+  ret = std::remove(lock_file_.c_str());
+  if (ret == 0) {
+    MS_LOG(DEBUG) << "Remove lock file.";
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardWriter::InitLockFile() {
+  if (file_paths_.size() == 0) {
+    MS_LOG(ERROR) << "File path not initialized.";
+    return FAILED;
+  }
+
+  lock_file_ = file_paths_[0] + kLockFileSuffix;
+  pages_file_ = file_paths_[0] + kPageFileSuffix;
+
+  if (RemoveLockFile() == FAILED) {
+    MS_LOG(ERROR) << "Remove lock file failed.";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
+  shard_count_ = paths.size();
+  if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
+    MS_LOG(ERROR) << "The shard count is greater than the max value or equal to 0.";
+    return FAILED;
+  }
+  if (schema_count_ > kMaxSchemaCount) {
+    MS_LOG(ERROR) << "The schema count is greater than the max value.";
+    return FAILED;
+  }
+
+  // Get full path from file name
+  if (GetFullPathFromFileName(paths) == FAILED) {
+    MS_LOG(ERROR) << "Get full path from file name failed.";
+    return FAILED;
+  }
+
+  // Open files
+  if (OpenDataFiles(append) == FAILED) {
+    MS_LOG(ERROR) << "Open data files failed.";
+    return FAILED;
+  }
+
+  // Init lock file
+  if (InitLockFile() == FAILED) {
+    MS_LOG(ERROR) << "Init lock file failed.";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
 MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
   if (!IsLegalFile(path)) {
     return FAILED;
   }
@@ -143,11 +197,28 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
 }
 
 MSRStatus ShardWriter::Commit() {
+  // Read pages file
+  std::ifstream page_file(pages_file_.c_str());
+  if (page_file.good()) {
+    page_file.close();
+    if (shard_header_->FileToPages(pages_file_) == FAILED) {
+      MS_LOG(ERROR) << "Read pages from file failed";
+      return FAILED;
+    }
+  }
+
   if (WriteShardHeader() == FAILED) {
     MS_LOG(ERROR) << "Write metadata failed";
     return FAILED;
   }
   MS_LOG(INFO) << "Write metadata successfully.";
+
+  // Remove lock file
+  if (RemoveLockFile() == FAILED) {
+    MS_LOG(ERROR) << "Remove lock file failed.";
+    return FAILED;
+  }
+
   return SUCCESS;
 }
@@ -455,15 +526,65 @@ void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>>
   }
 }
 
-MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
-                                    std::vector<std::vector<uint8_t>> &blob_data, bool sign) {
+int ShardWriter::LockWriter(bool parallel_writer) {
+  if (!parallel_writer) {
+    return 0;
+  }
+  const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
+  if (fd >= 0) {
+    flock(fd, LOCK_EX);
+  } else {
+    MS_LOG(ERROR) << "Shard writer failed when locking file";
+    return -1;
+  }
+
+  // Open files
+  file_streams_.clear();
+  for (const auto &file : file_paths_) {
+    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
+    fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary);
+    if (fs->fail()) {
+      MS_LOG(ERROR) << "File could not be opened";
+      return -1;
+    }
+    file_streams_.push_back(fs);
+  }
+
+  if (shard_header_->FileToPages(pages_file_) == FAILED) {
+    MS_LOG(ERROR) << "Read pages from file failed";
+    return -1;
+  }
+  return fd;
+}
+
+MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
+  if (!parallel_writer) {
+    return SUCCESS;
+  }
+
+  if (shard_header_->PagesToFile(pages_file_) == FAILED) {
+    MS_LOG(ERROR) << "Write pages to file failed";
+    return FAILED;
+  }
+
+  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
+    file_streams_[i]->close();
+  }
+
+  flock(fd, LOCK_UN);
+  close(fd);
+  return SUCCESS;
+}
+
+MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
+                                            std::vector<std::vector<uint8_t>> &blob_data, bool sign,
+                                            int *schema_count, int *row_count) {
   // check the free disk size
   auto st_space = GetDiskSize(file_paths_[0], kFreeSize);
   if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) {
     MS_LOG(ERROR) << "IO error / there is no free disk to be used";
     return FAILED;
   }
-
   // Add 4-bytes dummy blob data if no any blob fields
   if (blob_data.size() == 0 && raw_data.size() > 0) {
     blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
@@ -479,10 +600,29 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_d
     MS_LOG(ERROR) << "Validate raw data failed";
     return FAILED;
   }
+  *schema_count = std::get<1>(v);
+  *row_count = std::get<2>(v);
+  return SUCCESS;
+}
+
+MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
+                                    std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
+  // Lock the writer if loading data in parallel
+  int fd = LockWriter(parallel_writer);
+  if (fd < 0) {
+    MS_LOG(ERROR) << "Lock writer failed";
+    return FAILED;
+  }
 
   // Get the count of schemas and rows
-  int schema_count = std::get<1>(v);
-  int row_count = std::get<2>(v);
+  int schema_count = 0;
+  int row_count = 0;
+
+  // Pre-check the raw data
+  if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) {
+    MS_LOG(ERROR) << "Check raw data failed";
+    return FAILED;
+  }
 
   if (row_count == kInt0) {
     MS_LOG(INFO) << "Raw data size is 0.";
@@ -516,11 +656,17 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_d
   }
   MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully.";
+
+  if (UnlockWriter(fd, parallel_writer) == FAILED) {
+    MS_LOG(ERROR) << "Unlock writer failed";
+    return FAILED;
+  }
+
   return SUCCESS;
 }
 
 MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
-                                    std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign) {
+                                    std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
+                                    bool parallel_writer) {
   std::map<uint64_t, std::vector<json>> raw_data_json;
   std::map<uint64_t, std::vector<json>> blob_data_json;
@@ -554,11 +700,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
     MS_LOG(ERROR) << "Serialize raw data failed in write raw data";
     return FAILED;
   }
-  return WriteRawData(raw_data_json, bin_blob_data, sign);
+  return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
 }
 
 MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
-                                    vector<vector<uint8_t>> &blob_data, bool sign) {
+                                    vector<vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
   std::map<uint64_t, std::vector<json>> raw_data_json;
   (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
                        [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
@@ -568,7 +714,7 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
                        [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                          return std::make_pair(pair.first, std::move(json_raw_data));
                        });
-  return WriteRawData(raw_data_json, blob_data, sign);
+  return WriteRawData(raw_data_json, blob_data, sign, parallel_writer);
 }
 
 MSRStatus ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc
index 57b2e5fa9e..26008e3ca9 100644
--- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc
+++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc
@@ -677,5 +677,43 @@ std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(
   }
   return std::make_pair(statistics_.at(statistic_id), SUCCESS);
 }
+
+MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
+  // write header content to file, dumping whatever was in the file before
+  std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out);
+  if (page_out_handle.fail()) {
+    MS_LOG(ERROR) << "Failed in opening page file";
+    return FAILED;
+  }
+
+  auto pages = SerializePage();
+  for (const auto &shard_pages : pages) {
+    page_out_handle << shard_pages << "\n";
+  }
+
+  page_out_handle.close();
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
+  for (auto &v : pages_) {  // clean pages
+    v.clear();
+  }
+  // attempt to open the file that contains the pages in json
+  std::ifstream page_in_handle(dump_file_name.c_str());
+
+  if (!page_in_handle.good()) {
+    MS_LOG(INFO) << "No page file exists.";
+    return SUCCESS;
+  }
+
+  std::string line;
+  while (std::getline(page_in_handle, line)) {
+    ParsePage(json::parse(line));
+  }
+
+  page_in_handle.close();
+  return SUCCESS;
+}
 }  // namespace mindrecord
 }  // namespace mindspore
diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py
index 90bca48038..62bcc2df79 100644
--- a/mindspore/mindrecord/filewriter.py
+++ b/mindspore/mindrecord/filewriter.py
@@ -200,13 +200,24 @@ class FileWriter:
                     raw_data.pop(i)
                     logger.warning(v)
 
-    def write_raw_data(self, raw_data):
+    def open_and_set_header(self):
+        """
+        Open the writer and set the header.
+        """
+        if not self._writer.is_open:
+            self._writer.open(self._paths)
+        if not self._writer.get_shard_header():
+            self._writer.set_shard_header(self._header)
+
+    def write_raw_data(self, raw_data, parallel_writer=False):
         """
         Write raw data and generate sequential pair of MindRecord File and \
         validate data based on predefined schema by default.
 
         Args:
            raw_data (list[dict]): List of raw data.
+           parallel_writer (bool, optional): Write data in parallel if it is True (default=False).
 
        Raises:
            ParamTypeError: If index field is invalid.
@@ -225,7 +236,7 @@ class FileWriter:
             if not isinstance(each_raw, dict):
                 raise ParamTypeError('raw_data item', 'dict')
         self._verify_based_on_schema(raw_data)
-        return self._writer.write_raw_data(raw_data, True)
+        return self._writer.write_raw_data(raw_data, True, parallel_writer)
 
     def set_header_size(self, header_size):
         """
diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py
index 0ef23d4ce6..0913201861 100644
--- a/mindspore/mindrecord/shardwriter.py
+++ b/mindspore/mindrecord/shardwriter.py
@@ -135,7 +135,7 @@ class ShardWriter:
     def get_shard_header(self):
         return self._header
 
-    def write_raw_data(self, data, validate=True):
+    def write_raw_data(self, data, validate=True, parallel_writer=False):
         """
         Write raw data of cv dataset.
 
@@ -145,6 +145,7 @@ class ShardWriter:
         Args:
             data (list[dict]): List of raw data.
             validate (bool, optional): verify data according schema if it equals to True.
+            parallel_writer (bool, optional): Write data in parallel if it is True.
 
         Returns:
             MSRStatus, SUCCESS or FAILED.
@@ -165,7 +166,7 @@ class ShardWriter:
         if row_raw:
             raw_data.append(row_raw)
         raw_data = {0: raw_data} if raw_data else {}
-        ret = self._writer.write_raw_data(raw_data, blob_data, validate)
+        ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer)
         if ret != ms.MSRStatus.SUCCESS:
             logger.error("Failed to write dataset.")
             raise MRMWriteDatasetError
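At the Python layer, the new surface is open_and_set_header() plus the parallel_writer flag on write_raw_data. A minimal single-process sketch of that surface, with a placeholder file name and schema:

```python
from mindspore.mindrecord import FileWriter

writer = FileWriter("/tmp/demo.mindrecord", 1)  # hypothetical output file
writer.add_schema({"label": {"type": "int32"}}, "demo_schema")
writer.open_and_set_header()  # new: open explicitly before (possibly parallel) writes
writer.write_raw_data([{"label": i} for i in range(10)], parallel_writer=False)
writer.commit()
```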