From 9c8af0d1cf98ea229f29b3afd288d7baf6c56cba Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Fri, 14 Aug 2020 17:38:16 -0400 Subject: [PATCH] C++ API: Provide validate param functions --- .../ccsrc/minddata/dataset/api/datasets.cc | 136 +++++++++--------- .../ccsrc/minddata/dataset/include/datasets.h | 35 ++--- .../dataset/c_api_dataset_filetext_test.cc | 3 +- 3 files changed, 93 insertions(+), 81 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 01be88b503..eb68832675 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -105,7 +105,7 @@ Dataset::Dataset() { // Function to create a CelebADataset. std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &dataset_type, - const std::shared_ptr &sampler, const bool &decode, + const std::shared_ptr &sampler, bool decode, const std::set &extensions) { auto ds = std::make_shared(dataset_dir, dataset_type, sampler, decode, extensions); @@ -114,7 +114,7 @@ std::shared_ptr CelebA(const std::string &dataset_dir, const std: } // Function to create a Cifar10Dataset. -std::shared_ptr Cifar10(const std::string &dataset_dir, std::shared_ptr sampler) { +std::shared_ptr Cifar10(const std::string &dataset_dir, const std::shared_ptr &sampler) { auto ds = std::make_shared(dataset_dir, sampler); // Call derived class validation method. @@ -122,7 +122,7 @@ std::shared_ptr Cifar10(const std::string &dataset_dir, std::sha } // Function to create a Cifar100Dataset. -std::shared_ptr Cifar100(const std::string &dataset_dir, std::shared_ptr sampler) { +std::shared_ptr Cifar100(const std::string &dataset_dir, const std::shared_ptr &sampler) { auto ds = std::make_shared(dataset_dir, sampler); // Call derived class validation method. @@ -131,8 +131,8 @@ std::shared_ptr Cifar100(const std::string &dataset_dir, std::s // Function to create a CLUEDataset. std::shared_ptr CLUE(const std::vector &clue_files, const std::string &task, - const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards, - int shard_id) { + const std::string &usage, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id) { auto ds = std::make_shared(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id); // Call derived class validation method. @@ -150,9 +150,10 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::str } // Function to create a ImageFolderDataset. -std::shared_ptr ImageFolder(std::string dataset_dir, bool decode, - std::shared_ptr sampler, std::set extensions, - std::map class_indexing) { +std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode, + const std::shared_ptr &sampler, + const std::set &extensions, + const std::map &class_indexing) { // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false. bool recursive = false; @@ -164,7 +165,7 @@ std::shared_ptr ImageFolder(std::string dataset_dir, bool de } // Function to create a MnistDataset. -std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler) { +std::shared_ptr Mnist(const std::string &dataset_dir, const std::shared_ptr &sampler) { auto ds = std::make_shared(dataset_dir, sampler); // Call derived class validation method. @@ -181,7 +182,7 @@ std::shared_ptr operator+(const std::shared_ptr &dataset } // Function to create a TextFileDataset. -std::shared_ptr TextFile(std::vector dataset_files, int32_t num_samples, +std::shared_ptr TextFile(const std::vector &dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) { auto ds = std::make_shared(dataset_files, num_samples, shuffle, num_shards, shard_id); @@ -191,9 +192,9 @@ std::shared_ptr TextFile(std::vector dataset_files // Function to create a VOCDataset. std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode, - const std::map &class_index, bool decode, - std::shared_ptr sampler) { - auto ds = std::make_shared(dataset_dir, task, mode, class_index, decode, sampler); + const std::map &class_indexing, bool decode, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, task, mode, class_indexing, decode, sampler); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -402,16 +403,57 @@ Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, in return Status::OK(); } -// Helper function to validate dataset params -bool ValidateCommonDatasetParams(std::string dataset_dir) { +// Helper function to validate dataset directory parameter +bool ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { if (dataset_dir.empty()) { - MS_LOG(ERROR) << "No dataset path is specified"; + MS_LOG(ERROR) << dataset_name << ": dataset_dir is not specified."; return false; } + + Path dir(dataset_dir); + if (!dir.IsDirectory()) { + MS_LOG(ERROR) << dataset_name << ": dataset_dir: [" << dataset_dir << "] is an invalid directory path."; + return false; + } + if (access(dataset_dir.c_str(), R_OK) == -1) { - MS_LOG(ERROR) << "No access to specified dataset path: " << dataset_dir; + MS_LOG(ERROR) << dataset_name << ": No access to specified dataset path: " << dataset_dir; return false; } + + return true; +} + +// Helper function to validate dataset dataset files parameter +bool ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector &dataset_files) { + if (dataset_files.empty()) { + MS_LOG(ERROR) << dataset_name << ": dataset_files is not specified."; + return false; + } + + for (auto f : dataset_files) { + Path dataset_file(f); + if (!dataset_file.Exists()) { + MS_LOG(ERROR) << dataset_name << ": dataset file: [" << f << "] is invalid or does not exist."; + return false; + } + } + + return true; +} + +// Helper function to validate dataset num_shards and shard_id parameters +bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { + if (num_shards <= 0) { + MS_LOG(ERROR) << dataset_name << ": Invalid num_shards: " << num_shards; + return false; + } + + if (shard_id < 0 || shard_id >= num_shards) { + MS_LOG(ERROR) << dataset_name << ": Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; + return false; + } + return true; } @@ -431,9 +473,7 @@ CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string & extensions_(extensions) {} bool CelebADataset::ValidateParams() { - Path dir(dataset_dir_); - if (!dir.IsDirectory()) { - MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; + if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) { return false; } std::set dataset_type_list = {"all", "train", "valid", "test"}; @@ -471,7 +511,7 @@ std::vector> CelebADataset::Build() { Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr sampler) : dataset_dir_(dataset_dir), sampler_(sampler) {} -bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } +bool Cifar10Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_); } // Function to build CifarOp for Cifar10 std::vector> Cifar10Dataset::Build() { @@ -500,7 +540,7 @@ std::vector> Cifar10Dataset::Build() { Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr sampler) : dataset_dir_(dataset_dir), sampler_(sampler) {} -bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } +bool Cifar100Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_); } // Function to build CifarOp for Cifar100 std::vector> Cifar100Dataset::Build() { @@ -529,7 +569,7 @@ std::vector> Cifar100Dataset::Build() { // Constructor for CLUEDataset CLUEDataset::CLUEDataset(const std::vector clue_files, std::string task, std::string usage, - int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id) + int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) : dataset_files_(clue_files), task_(task), usage_(usage), @@ -539,19 +579,10 @@ CLUEDataset::CLUEDataset(const std::vector clue_files, std::string shard_id_(shard_id) {} bool CLUEDataset::ValidateParams() { - if (dataset_files_.empty()) { - MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified."; + if (!ValidateDatasetFilesParam("CLUEDataset", dataset_files_)) { return false; } - for (auto f : dataset_files_) { - Path clue_file(f); - if (!clue_file.Exists()) { - MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or not exist"; - return false; - } - } - std::vector task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; std::vector usage_list = {"train", "test", "eval"}; @@ -570,13 +601,7 @@ bool CLUEDataset::ValidateParams() { return false; } - if (num_shards_ <= 0) { - MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_; - return false; - } - - if (shard_id_ < 0 || shard_id_ >= num_shards_) { - MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; + if (!ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)) { return false; } @@ -734,9 +759,7 @@ CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &anno : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} bool CocoDataset::ValidateParams() { - Path dir(dataset_dir_); - if (!dir.IsDirectory()) { - MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; + if (!ValidateDatasetDirParam("CocoDataset", dataset_dir_)) { return false; } Path annotation_file(annotation_file_); @@ -829,7 +852,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std class_indexing_(class_indexing), exts_(extensions) {} -bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } +bool ImageFolderDataset::ValidateParams() { return ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_); } std::vector> ImageFolderDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create @@ -857,7 +880,7 @@ std::vector> ImageFolderDataset::Build() { MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr sampler) : dataset_dir_(dataset_dir), sampler_(sampler) {} -bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } +bool MnistDataset::ValidateParams() { return ValidateDatasetDirParam("MnistDataset", dataset_dir_); } std::vector> MnistDataset::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create @@ -890,31 +913,16 @@ TextFileDataset::TextFileDataset(std::vector dataset_files, int32_t shard_id_(shard_id) {} bool TextFileDataset::ValidateParams() { - if (dataset_files_.empty()) { - MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified."; + if (!ValidateDatasetFilesParam("TextFileDataset", dataset_files_)) { return false; } - for (auto file : dataset_files_) { - std::ifstream handle(file); - if (!handle.is_open()) { - MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file; - return false; - } - } - if (num_samples_ < 0) { MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_; return false; } - if (num_shards_ <= 0) { - MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_; - return false; - } - - if (shard_id_ < 0 || shard_id_ >= num_shards_) { - MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; + if (!ValidateDatasetShardParams("TextfileDataset", num_shards_, shard_id_)) { return false; } @@ -960,12 +968,12 @@ std::vector> TextFileDataset::Build() { // Constructor for VOCDataset VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, - const std::map &class_index, bool decode, + const std::map &class_indexing, bool decode, std::shared_ptr sampler) : dataset_dir_(dataset_dir), task_(task), mode_(mode), - class_index_(class_index), + class_index_(class_indexing), decode_(decode), sampler_(sampler) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 051c29ef1b..b193179c1f 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -75,7 +75,7 @@ class ZipDataset; /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all", - const std::shared_ptr &sampler = nullptr, const bool &decode = false, + const std::shared_ptr &sampler = nullptr, bool decode = false, const std::set &extensions = {}); /// \brief Function to create a Cifar10 Dataset @@ -84,7 +84,8 @@ std::shared_ptr CelebA(const std::string &dataset_dir, const std: /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar10(const std::string &dataset_dir, std::shared_ptr sampler = nullptr); +std::shared_ptr Cifar10(const std::string &dataset_dir, + const std::shared_ptr &sampler = nullptr); /// \brief Function to create a Cifar100 Dataset /// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label'] @@ -93,7 +94,7 @@ std::shared_ptr Cifar10(const std::string &dataset_dir, std::sha /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr Cifar100(const std::string &dataset_dir, - std::shared_ptr sampler = nullptr); + const std::shared_ptr &sampler = nullptr); /// \brief Function to create a CLUEDataset /// \notes The generated dataset has a variable number of columns depending on the task and usage @@ -114,7 +115,8 @@ std::shared_ptr Cifar100(const std::string &dataset_dir, /// \return Shared pointer to the current CLUEDataset std::shared_ptr CLUE(const std::vector &dataset_files, const std::string &task = "AFQMC", const std::string &usage = "train", int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0); + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, + int32_t shard_id = 0); /// \brief Function to create a CocoDataset /// \notes The generated dataset has multi-columns : @@ -147,10 +149,10 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::str /// \param[in] extensions File extensions to be read /// \param[in] class_indexing a class name to label map /// \return Shared pointer to the current ImageFolderDataset -std::shared_ptr ImageFolder(std::string dataset_dir, bool decode = false, - std::shared_ptr sampler = nullptr, - std::set extensions = {}, - std::map class_indexing = {}); +std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode = false, + const std::shared_ptr &sampler = nullptr, + const std::set &extensions = {}, + const std::map &class_indexing = {}); /// \brief Function to create a MnistDataset /// \notes The generated dataset has two columns ['image', 'label'] @@ -158,7 +160,8 @@ std::shared_ptr ImageFolder(std::string dataset_dir, bool de /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, /// A `RandomSampler` will be used to randomly iterate the entire dataset /// \return Shared pointer to the current MnistDataset -std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler = nullptr); +std::shared_ptr Mnist(const std::string &dataset_dir, + const std::shared_ptr &sampler = nullptr); /// \brief Function to create a ConcatDataset /// \notes Reload "+" operator to concat two datasets @@ -183,7 +186,7 @@ std::shared_ptr operator+(const std::shared_ptr &dataset /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) /// \return Shared pointer to the current TextFileDataset -std::shared_ptr TextFile(std::vector dataset_files, int32_t num_samples = 0, +std::shared_ptr TextFile(const std::vector &dataset_files, int32_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); @@ -202,8 +205,8 @@ std::shared_ptr TextFile(std::vector dataset_files /// \return Shared pointer to the current Dataset std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", const std::string &mode = "train", - const std::map &class_index = {}, bool decode = false, - std::shared_ptr sampler = nullptr); + const std::map &class_indexing = {}, bool decode = false, + const std::shared_ptr &sampler = nullptr); /// \brief Function to create a ZipDataset /// \notes Applies zip to the dataset @@ -417,7 +420,7 @@ class CLUEDataset : public Dataset { public: /// \brief Constructor CLUEDataset(const std::vector dataset_files, std::string task, std::string usage, int64_t num_samples, - ShuffleMode shuffle, int num_shards, int shard_id); + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); /// \brief Destructor ~CLUEDataset() = default; @@ -440,8 +443,8 @@ class CLUEDataset : public Dataset { std::string usage_; int64_t num_samples_; ShuffleMode shuffle_; - int num_shards_; - int shard_id_; + int32_t num_shards_; + int32_t shard_id_; }; class CocoDataset : public Dataset { @@ -549,7 +552,7 @@ class VOCDataset : public Dataset { public: /// \brief Constructor VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, - const std::map &class_index, bool decode, std::shared_ptr sampler); + const std::map &class_indexing, bool decode, std::shared_ptr sampler); /// \brief Destructor ~VOCDataset() = default; diff --git a/tests/ut/cpp/dataset/c_api_dataset_filetext_test.cc b/tests/ut/cpp/dataset/c_api_dataset_filetext_test.cc index 6e27a14e22..ba5909d15e 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_filetext_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_filetext_test.cc @@ -111,7 +111,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail3) { // Attempt to create a TextFile Dataset // with non-existent dataset_files input - std::shared_ptr ds = TextFile({"notexist.txt"}, 0, ShuffleMode::kFalse); + std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt"; + std::shared_ptr ds = TextFile({tf_file1, "notexist.txt"}, 0, ShuffleMode::kFalse); // Expect failure: specified dataset_files does not exist EXPECT_EQ(ds, nullptr);