From 8e60cfe80027b085a1c7f90bf38ba086fb386dcc Mon Sep 17 00:00:00 2001 From: luoyang Date: Thu, 22 Oct 2020 21:15:33 +0800 Subject: [PATCH] [MD] C++ api save --- .../ccsrc/minddata/dataset/api/datasets.cc | 46 +++ .../ccsrc/minddata/dataset/api/iterator.cc | 4 +- .../dataset/engine/consumers/tree_consumer.cc | 296 +++++++++++++++++- .../dataset/engine/consumers/tree_consumer.h | 37 ++- .../ccsrc/minddata/dataset/include/datasets.h | 16 + tests/ut/cpp/dataset/c_api_dataset_save.cc | 148 +++++++++ 6 files changed, 532 insertions(+), 15 deletions(-) create mode 100644 tests/ut/cpp/dataset/c_api_dataset_save.cc diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 3f3cce764e..1d71bc8db0 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -103,6 +103,52 @@ std::shared_ptr Dataset::CreateIterator(std::vector colum return iter; } +#ifndef ENABLE_ANDROID +// Function to create the saver, which will build and launch the execution tree and save data +bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) { + Status rc; + // Build and launch tree + auto ds = shared_from_this(); + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "CreateSaver failed." << rc; + return false; + } + + // Get SaveToDisk consumer + auto consumer = std::make_unique(dataset_path, num_files, dataset_type); + rc = consumer->ValidateParams(); + if (rc.IsError()) { + MS_LOG(ERROR) << "CreateSaver failed." << rc; + return false; + } + SaveToDisk *consumer_ = consumer.get(); + rc = consumer->Init(ds); + if (rc.IsError()) { + MS_LOG(ERROR) << "CreateSaver failed." << rc; + return false; + } + runtime_context->AssignConsumer(std::move(consumer)); + + // Save data into file + rc = consumer_->Save(); + if (rc.IsError()) { + MS_LOG(ERROR) << "Saver: Failed to save data into file. Error status: " << rc; + return false; + } + + // Shut down the data pipeline + rc = runtime_context->Terminate(); + if (rc.IsError()) { + MS_LOG(ERROR) << "Saver: Failed to shut down pipeline. Error status: " << rc; + return false; + } + + return true; +} +#endif + // Constructor Dataset::Dataset() { // Fetch some default value from config manager diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index 75991adfe2..39082e626d 100644 --- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -46,8 +46,8 @@ bool Iterator::GetNextRow(TensorVec *row) { // Shut down the data pipeline. void Iterator::Stop() { runtime_context->Terminate(); } -// -//// Function to build and launch the execution tree. + +// Function to build and launch the execution tree. Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { runtime_context = std::make_unique(); RETURN_IF_NOT_OK(runtime_context->Init()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 814953b15e..21b212e973 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -17,14 +17,30 @@ #include #include #include +#include #include #include #include #include "minddata/dataset/engine/consumers/tree_consumer.h" #include "minddata/dataset/engine/tree_adapter.h" +#ifndef ENABLE_ANDROID +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_writer.h" +#endif + namespace mindspore::dataset { +// TreeConsumer +TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique(); } + +Status TreeConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } + +// IteratorConsumer +Status IteratorConsumer::Init(std::shared_ptr d) { + return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); +} + Status IteratorConsumer::GetNextAsVector(std::vector *out) { RETURN_UNEXPECTED_IF_NULL(out); out->clear(); @@ -38,6 +54,7 @@ Status IteratorConsumer::GetNextAsVector(std::vector *out) { std::copy(res.begin(), res.end(), std::back_inserter(*out)); return Status::OK(); } + Status IteratorConsumer::GetNextAsMap(std::unordered_map *out_map) { RETURN_UNEXPECTED_IF_NULL(out_map); out_map->clear(); @@ -55,13 +72,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map return Status::OK(); } -TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique(); } - -Status IteratorConsumer::Init(std::shared_ptr d) { - return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); -} -Status TreeConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } - +// ToDevice Status ToDevice::Init(std::shared_ptr d) { // TODO(CRC): // Get device ID from children look at get_distribution in python @@ -69,4 +80,275 @@ Status ToDevice::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); } + +#ifndef ENABLE_ANDROID +// SaveToDisk +Status SaveToDisk::ValidateParams() { + if (dataset_path_.empty()) { + std::string err = "CreateSaver failed, dataset_path must not be empty"; + MS_LOG(ERROR) << err; + RETURN_STATUS_SYNTAX_ERROR(err); + } + Path dir(dataset_path_); + if (dir.IsDirectory()) { + std::string err = "CreateSaver failed, dataset_path must not be a directory"; + MS_LOG(ERROR) << err; + RETURN_STATUS_SYNTAX_ERROR(err); + } + if (num_files_ <= 0 || num_files_ > 1000) { + std::string err = "CreateSaver failed, num_files must between 1 and 1000, but got " + std::to_string(num_files_); + MS_LOG(ERROR) << err; + RETURN_STATUS_SYNTAX_ERROR(err); + } + if (dataset_type_ != "mindrecord") { + std::string err = "CreateSaver failed, only \"mindrecord\" dataset format is supported, but got " + dataset_type_; + MS_LOG(ERROR) << err; + RETURN_STATUS_SYNTAX_ERROR(err); + } + return Status::OK(); +} + +Status SaveToDisk::Save() { + std::vector file_names; + if (num_files_ == 1) { + file_names.push_back(dataset_path_); + } else { + for (int32_t i = 0; i < num_files_; i++) { + file_names.push_back(dataset_path_ + std::to_string(i)); + } + } + + auto mr_header = std::make_shared(); + auto mr_writer = std::make_unique(); + std::vector blob_fields; + if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) { + RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter."); + } + + std::unordered_map column_name_id_map; + for (auto el : tree_adapter_->GetColumnNameMap()) { + std::string column_name = el.first; + std::transform(column_name.begin(), column_name.end(), column_name.begin(), + [](unsigned char c) { return ispunct(c) ? '_' : c; }); + column_name_id_map[column_name] = el.second; + } + + TensorRow row; + uint64_t mr_schema_id = 0; + bool first_loop = true; // build schema in first loop + do { + nlohmann::json row_raw_data; + std::map>> row_bin_data; + RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); + if (row.empty()) break; + if (first_loop) { + nlohmann::json mr_json; + std::vector index_fields; + RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields)); + MS_LOG(DEBUG) << "Schema of saved mindrecord: " << mr_json.dump(); + if (mindrecord::SUCCESS != + mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) { + RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader."); + } + mr_writer->SetShardHeader(mr_header); + first_loop = false; + } + // construct data + if (!row.empty()) { // write data + RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data)); + std::shared_ptr> output_bin_data; + mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data); + std::map> raw_data; + raw_data.insert( + std::pair>(mr_schema_id, std::vector{row_raw_data})); + std::vector> bin_data; + if (nullptr != output_bin_data) { + bin_data.emplace_back(*output_bin_data); + } + mr_writer->WriteRawData(raw_data, bin_data); + } + } while (!row.empty()); + + mr_writer->Commit(); + if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) { + RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator."); + } + return Status::OK(); +} + +Status SaveToDisk::FetchMetaFromTensorRow(const std::unordered_map &column_name_id_map, + const TensorRow &row, nlohmann::json *schema, + std::vector *index_fields) { + if (schema == nullptr) { + RETURN_STATUS_UNEXPECTED("Error: schema is NULL."); + } + if (index_fields == nullptr) { + RETURN_STATUS_UNEXPECTED("Error: index fields is NULL."); + } + if (column_name_id_map.empty()) { + RETURN_STATUS_UNEXPECTED("Error: column not found."); + } + nlohmann::json dataset_schema; + for (auto &col : column_name_id_map) { + auto idx = col.second; + auto column_name = col.first; + auto &tensor = row[idx]; + auto column_type = tensor->type(); + auto column_shape = tensor->shape(); + + std::string mr_type; + auto shapes = column_shape.AsVector(); + std::vector mr_shape(shapes.begin(), shapes.end()); + std::string el = column_type.ToString(); + dataset_schema[column_name] = el; + if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) { + std::string err_msg("Error: can not support data type: " + el); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + mr_type = mindrecord::kTypesMap.at(el); + } + if (mr_shape.empty()) { + if (mr_type == "bytes") { // map to int32 when bytes without shape. + mr_type = "int32"; + } + (*schema)[column_name] = {{"type", mr_type}}; + } else { + if (mr_type == "string") { // mindrecord can not support string with shape. + std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (mr_type == "bytes") { // ignore shape of bytes in minrecord + (*schema)[column_name] = {{"type", mr_type}}; + } else { + (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}}; + } + } + if (mr_type == "bytes" || !mr_shape.empty()) continue; + index_fields->emplace_back(column_name); // candidate of index fields + } + MS_LOG(DEBUG) << "Schema of dataset: " << dataset_schema.dump(); + return Status::OK(); +} + +Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row, + const std::unordered_map &column_name_id_map, + nlohmann::json *row_raw_data, + std::map>> *row_bin_data) { + if (row_raw_data == nullptr) { + RETURN_STATUS_UNEXPECTED("Error: row raw data is NULL."); + } + if (row_bin_data == nullptr) { + RETURN_STATUS_UNEXPECTED("Error: row bin data is NULL."); + } + if (column_name_id_map.empty()) { + RETURN_STATUS_UNEXPECTED("Error: column not found"); + } + Status s; + for (auto &col : column_name_id_map) { + auto idx = col.second; + auto column_name = col.first; + auto &tensor = row[idx]; + auto column_type = tensor->type(); + + std::unique_ptr> data_ptr; + if (column_type == DataType::DE_INT8) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT16) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT16) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT8) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT32) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT32) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT64) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_FLOAT32) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_FLOAT64) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_STRING) { + std::string_view sv; + RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {0})); // assume scalar string tensor + std::string ss(sv); + (*row_raw_data)[column_name] = std::move(ss); + continue; + } else { + RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data."); + } + RETURN_IF_NOT_OK(s); + if (data_ptr != nullptr) { + (*row_bin_data)[column_name] = std::move(data_ptr); + } + } + return Status::OK(); +} + +template +Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, + std::unique_ptr *data, std::unique_ptr> *data_ptr, + std::unique_ptr *s, bool need_convert) { + if (nullptr == src) { + RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL."); + } + *data_ptr = std::make_unique>(num_of_elements * sizeof(T)); + if (need_convert) { + auto tmp_ptr = std::make_unique>(num_of_elements * sizeof(S)); + std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin()); + auto s_ptr = reinterpret_cast(&(*(tmp_ptr->begin()))); + auto el = std::make_unique(); + for (uint32_t i = 0; i < num_of_elements; ++i) { + *el = *(s_ptr + i); + auto t_ptr = reinterpret_cast(el.get()); + for (uint32_t j = 0; j < sizeof(T); ++j) { + *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j); + } + } + } else { + std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin()); + } + if (shape.empty()) { + *data = std::make_unique(); + auto t_ptr = reinterpret_cast((*data).get()); + for (uint32_t i = 0; i < sizeof(T); ++i) { + *(t_ptr + i) = *((*data_ptr)->begin() + i); + } + } + return Status::OK(); +} +#endif + } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 00cf7e184f..34dc00eba5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -77,26 +78,50 @@ class IteratorConsumer : public TreeConsumer { int32_t num_epochs_; }; -/// Consumer that iterates over the dataset and writes it to desk -class SaveToDesk : public TreeConsumer { +#ifndef ENABLE_ANDROID +/// Consumer that iterates over the dataset and writes it to disk +class SaveToDisk : public TreeConsumer { public: /// Constructor which will call the base class default constructor. /// \param dataset_path path the the dataset /// \param num_files number of files. Default to 1 /// \param dataset_type The format of the dataset. Default to "mindrecod". - explicit SaveToDesk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") + explicit SaveToDisk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") : TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {} - /// Save the given dataset to MindRecord format on desk. This is a blocking method (i.e., after returning, all rows - /// would be written to desk) + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams(); + + /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows + /// would be written to disk) /// \return Status error code - Status Save() { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); } + Status Save(); + + protected: + /// Method to return the name of the consumer + /// \return string + std::string Name() override { return "SaveToDisk"; } private: + template + Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, + std::unique_ptr *data, std::unique_ptr> *data_ptr, + std::unique_ptr *s, bool need_convert = false); + + Status FetchMetaFromTensorRow(const std::unordered_map &column_name_id_map, + const TensorRow &row, nlohmann::json *schema, std::vector *index_fields); + + Status FetchDataFromTensorRow(const TensorRow &row, + const std::unordered_map &column_name_id_map, + nlohmann::json *row_raw_data, + std::map>> *row_bin_data); + std::string dataset_path_; int32_t num_files_; std::string dataset_type_; }; +#endif /// Consumer that iterates over the dataset and send it to a device class ToDevice : public TreeConsumer { diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index f221a322f4..cc2abfaa7f 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -530,6 +530,22 @@ class Dataset : public std::enable_shared_from_this { /// \return Shared pointer to the Iterator std::shared_ptr CreateIterator(std::vector columns = {}); +#ifndef ENABLE_ANDROID + /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline + /// \note Usage restrictions: + /// 1. Supported dataset formats: 'mindrecord' only + /// 2. To save the samples in order, set dataset's shuffle to false and num_files to 1. + /// 3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators + /// with random attribute in map operator. + /// 4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor + /// multi-dimensional string. + /// \param[in] file_name Path to dataset file + /// \param[in] num_files Number of dataset files (default=1) + /// \param[in] file_type Dataset format (default="mindrecord") + /// \return Returns true if no error encountered else false + bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord"); +#endif + /// \brief Function to create a BatchNode /// \notes Combines batch_size number of consecutive rows into batches /// \param[in] batch_size Path to the root directory that contains the dataset diff --git a/tests/ut/cpp/dataset/c_api_dataset_save.cc b/tests/ut/cpp/dataset/c_api_dataset_save.cc new file mode 100644 index 0000000000..9697b88ad0 --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_dataset_save.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/common.h" +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/transforms.h" + +using namespace mindspore::dataset::api; +using mindspore::dataset::Tensor; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +TEST_F(MindDataTestPipeline, TestSaveCifar10AndLoad) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSaveCifar10AndLoad(single mindrecord file)."; + + // Stage 1: load original dataset + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, "all", SequentialSampler(0, 10)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + std::vector> original_data; + iter->GetNextRow(&row); + + // Save original data for comparison + uint64_t i = 0; + while (row.size() != 0) { + auto label = row["label"]; + original_data.push_back(label); + MS_LOG(INFO) << "Tensor label: " << *label; + iter->GetNextRow(&row); + i++; + } + + // Expect 10 samples + EXPECT_EQ(i, 10); + // Manually terminate the pipeline + iter->Stop(); + + // Stage 2: Save data processed by the dataset pipeline + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::string temp_file = datasets_root_path_ + "/testCifar10Data/mind.mind"; + std::string temp_file_db = datasets_root_path_ + "/testCifar10Data/mind.mind.db"; + bool rc = ds->Save(temp_file); + EXPECT_EQ(rc, true); + + // Stage 3: Load dataset from file output by stage 2 + // Create a MindData Dataset + std::shared_ptr ds_minddata = MindData(temp_file, {}, SequentialSampler(0, 10)); + + // Create objects for the tensor ops + // uint32 will be casted to int64 implicitly in mindrecord file, so we have to cast it back to uint32 + std::shared_ptr type_cast = transforms::TypeCast("uint32"); + EXPECT_NE(type_cast, nullptr); + + // Create a Map operation on ds + ds_minddata = ds_minddata->Map({type_cast}, {"label"}); + EXPECT_NE(ds_minddata, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter_minddata = ds_minddata->CreateIterator(); + EXPECT_NE(iter_minddata, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row_minddata; + iter_minddata->GetNextRow(&row_minddata); + + // Check column name for each row + EXPECT_NE(row_minddata.find("image"), row_minddata.end()); + EXPECT_NE(row_minddata.find("label"), row_minddata.end()); + + // Expect the output data is same with original_data + uint64_t j = 0; + while (row_minddata.size() != 0) { + auto label = row_minddata["label"]; + EXPECT_EQ(*original_data[j], *label); + MS_LOG(INFO) << "Tensor label: " << *label; + iter_minddata->GetNextRow(&row_minddata); + j++; + } + + // Expect 10 samples + EXPECT_EQ(j, 10); + // Manually terminate the pipeline + iter_minddata->Stop(); + + // Delete temp file + EXPECT_EQ(remove(temp_file.c_str()), 0); + EXPECT_EQ(remove(temp_file_db.c_str()), 0); +} + +TEST_F(MindDataTestPipeline, TestSaveFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSaveFail with incorrect param."; + + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, "all", SequentialSampler(0, 10)); + EXPECT_NE(ds, nullptr); + + // fail with invalid dataset_path + std::string temp_file1 = ""; + bool rc1 = ds->Save(temp_file1); + EXPECT_EQ(rc1, false); + + // fail with invalid dataset_path + std::string temp_file2 = datasets_root_path_ + "/testCifar10Data/"; + bool rc2 = ds->Save(temp_file2); + EXPECT_EQ(rc2, false); + + // fail with invalid num_files + std::string temp_file3 = datasets_root_path_ + "/testCifar10Data/mind.mind"; + bool rc3 = ds->Save(temp_file3, 0); + EXPECT_EQ(rc3, false); + + // fail with invalid num_files + std::string temp_file4 = datasets_root_path_ + "/testCifar10Data/mind.mind"; + bool rc4 = ds->Save(temp_file4, 1001); + EXPECT_EQ(rc4, false); + + // fail with invalid dataset_type + std::string temp_file5 = datasets_root_path_ + "/testCifar10Data/mind.mind"; + bool rc5 = ds->Save(temp_file5, 5, "tfrecord"); + EXPECT_EQ(rc5, false); +}