diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 41cf704a69..aad9cad4f6 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -14,14 +14,14 @@ * limitations under the License. */ +#include #include #include -#include #include "minddata/dataset/include/datasets.h" #include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/transforms.h" -#include "minddata/dataset/engine/dataset_iterator.h" // Source dataset headers (in alphabetical order) +#include "minddata/dataset/engine/dataset_iterator.h" #include "minddata/dataset/engine/datasetops/source/album_op.h" #include "minddata/dataset/engine/datasetops/source/celeba_op.h" #include "minddata/dataset/engine/datasetops/source/cifar_op.h" @@ -56,13 +56,13 @@ #include "minddata/dataset/engine/datasetops/zip_op.h" // Sampler headers (in alphabetical order) -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/random.h" namespace mindspore { namespace dataset { @@ -671,98 +671,112 @@ Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, in } // Helper function to validate dataset directory parameter -bool ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { +Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { if (dataset_dir.empty()) { - MS_LOG(ERROR) << dataset_name << ": dataset_dir is not specified."; - return false; + std::string err_msg = dataset_name + ": dataset_dir is not specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } Path dir(dataset_dir); if (!dir.IsDirectory()) { - MS_LOG(ERROR) << dataset_name << ": dataset_dir: [" << dataset_dir << "] is an invalid directory path."; - return false; + std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (access(dataset_dir.c_str(), R_OK) == -1) { - MS_LOG(ERROR) << dataset_name << ": No access to specified dataset path: " << dataset_dir; - return false; + std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } // Helper function to validate dataset dataset files parameter -bool ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector &dataset_files) { +Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector &dataset_files) { if (dataset_files.empty()) { - MS_LOG(ERROR) << dataset_name << ": dataset_files is not specified."; - return false; + std::string err_msg = dataset_name + ": dataset_files is not specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } for (auto f : dataset_files) { Path dataset_file(f); if (!dataset_file.Exists()) { - MS_LOG(ERROR) << dataset_name << ": dataset file: [" << f << "] is invalid or does not exist."; - return false; + std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } } - return true; + return Status::OK(); } // Helper function to validate dataset num_shards and shard_id parameters -bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { +Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { if (num_shards <= 0) { - MS_LOG(ERROR) << dataset_name << ": Invalid num_shards: " << num_shards; - return false; + std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (shard_id < 0 || shard_id >= num_shards) { - MS_LOG(ERROR) << dataset_name << ": Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; - return false; + // num_shards; + std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + + ", num_shards: " + std::to_string(num_shards); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } // Helper function to validate dataset sampler parameter -bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr &sampler) { +Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr &sampler) { if (sampler == nullptr) { MS_LOG(ERROR) << dataset_name << ": Sampler is not constructed correctly, sampler: nullptr"; - return false; + std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } -bool ValidateStringValue(const std::string &str, const std::unordered_set &valid_strings) { +Status ValidateStringValue(const std::string &str, const std::unordered_set &valid_strings) { if (valid_strings.find(str) == valid_strings.end()) { std::string mode; mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); - MS_LOG(ERROR) << str << " does not match any mode in [" + mode + " ]"; - return false; + std::string err_msg = str + " does not match any mode in [" + mode + " ]"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } // Helper function to validate dataset input/output column parameter -bool ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, - const std::vector &columns) { +Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, + const std::vector &columns) { if (columns.empty()) { - MS_LOG(ERROR) << dataset_name << ":" << column_param << " should not be empty string"; - return false; + std::string err_msg = dataset_name + ":" + column_param + " should not be empty string"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } for (uint32_t i = 0; i < columns.size(); ++i) { if (columns[i].empty()) { - MS_LOG(ERROR) << dataset_name << ":" << column_param << "[" << i << "] should not be empty"; - return false; + std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty"; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } } std::set columns_set(columns.begin(), columns.end()); if (columns_set.size() != columns.size()) { - MS_LOG(ERROR) << dataset_name << ":" << column_param << ": Every column name should not be same with others"; - return false; + // others"; + std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others"; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } /* ####################################### Derived Dataset classes ################################# */ @@ -780,26 +794,29 @@ AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &da decode_(decode), sampler_(sampler) {} -bool AlbumDataset::ValidateParams() { - if (!ValidateDatasetDirParam("AlbumDataset", dataset_dir_)) { - return false; - } +Status AlbumDataset::ValidateParams() { + Status rc; - if (!ValidateDatasetFilesParam("AlbumDataset", {schema_path_})) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumDataset", dataset_dir_)); + + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumDataset", {schema_path_})); + if (rc.IsError()) { + return rc; } - if (!ValidateDatasetSampler("AlbumDataset", sampler_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumDataset", sampler_)); + if (rc.IsError()) { + return rc; } if (!column_names_.empty()) { - if (!ValidateDatasetColumnParam("AlbumDataset", "column_names", column_names_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumDataset", "column_names", column_names_)); + if (rc.IsError()) { + return rc; } } - return true; + return Status::OK(); } // Function to build AlbumDataset @@ -824,9 +841,25 @@ CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string & const std::set &extensions) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {} -bool CelebADataset::ValidateParams() { - return ValidateDatasetDirParam("CelebADataset", dataset_dir_) && ValidateDatasetSampler("CelebADataset", sampler_) && - ValidateStringValue(usage_, {"all", "train", "valid", "test"}); +Status CelebADataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebADataset", dataset_dir_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebADataset", sampler_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"})); + if (rc.IsError()) { + return rc; + } + + return Status::OK(); } // Function to build CelebADataset @@ -851,9 +884,25 @@ Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string std::shared_ptr sampler) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} -bool Cifar10Dataset::ValidateParams() { - return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) && - ValidateDatasetSampler("Cifar10Dataset", sampler_) && ValidateStringValue(usage_, {"train", "test", "all"}); +Status Cifar10Dataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Dataset", sampler_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); + if (rc.IsError()) { + return rc; + } + + return Status::OK(); } // Function to build CifarOp for Cifar10 @@ -879,9 +928,25 @@ Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::stri std::shared_ptr sampler) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} -bool Cifar100Dataset::ValidateParams() { - return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_) && - ValidateDatasetSampler("Cifar100Dataset", sampler_) && ValidateStringValue(usage_, {"train", "test", "all"}); +Status Cifar100Dataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Dataset", sampler_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); + if (rc.IsError()) { + return rc; + } + + return Status::OK(); } // Function to build CifarOp for Cifar100 @@ -915,34 +980,41 @@ CLUEDataset::CLUEDataset(const std::vector clue_files, std::string num_shards_(num_shards), shard_id_(shard_id) {} -bool CLUEDataset::ValidateParams() { - if (!ValidateDatasetFilesParam("CLUEDataset", dataset_files_)) { - return false; +Status CLUEDataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_)); + if (rc.IsError()) { + return rc; } std::vector task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; std::vector usage_list = {"train", "test", "eval"}; if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) { - MS_LOG(ERROR) << "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL."; - return false; + std::string err_msg = "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { - MS_LOG(ERROR) << "usage should be train, test or eval."; - return false; + std::string err_msg = "usage should be train, test or eval."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (num_samples_ < 0) { - MS_LOG(ERROR) << "CLUEDataset: Invalid number of samples: " << num_samples_; - return false; + std::string err_msg = "CLUEDataset: Invalid number of samples: " + num_samples_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)); + if (rc.IsError()) { + return rc; } - return true; + return Status::OK(); } // Function to split string based on a character delimiter @@ -1100,25 +1172,35 @@ CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &anno const bool &decode, const std::shared_ptr &sampler) : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} -bool CocoDataset::ValidateParams() { - if (!ValidateDatasetDirParam("CocoDataset", dataset_dir_)) { - return false; +Status CocoDataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoDataset", dataset_dir_)); + if (rc.IsError()) { + return rc; } - if (!ValidateDatasetSampler("CocoDataset", sampler_)) { - return false; + + RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoDataset", sampler_)); + if (rc.IsError()) { + return rc; } + Path annotation_file(annotation_file_); if (!annotation_file.Exists()) { - MS_LOG(ERROR) << "annotation_file is invalid or not exist"; - return false; + std::string err_msg = "annotation_file is invalid or not exist"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } + std::set task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"}; auto task_iter = task_list.find(task_); if (task_iter == task_list.end()) { - MS_LOG(ERROR) << "Invalid task type"; - return false; + std::string err_msg = "Invalid task type"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + + return Status::OK(); } // Function to build CocoDataset @@ -1172,7 +1254,7 @@ std::vector> CocoDataset::Build() { schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); break; default: - MS_LOG(ERROR) << "CocoDataset::Build : Invalid task type"; + MS_LOG(ERROR) << "CocoDataset::Build : Invalid task type: " << task_type; return {}; } std::shared_ptr op = @@ -1196,37 +1278,50 @@ CSVDataset::CSVDataset(const std::vector &csv_files, char field_del num_shards_(num_shards), shard_id_(shard_id) {} -bool CSVDataset::ValidateParams() { - if (!ValidateDatasetFilesParam("CSVDataset", dataset_files_)) { - return false; +Status CSVDataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVDataset", dataset_files_)); + if (rc.IsError()) { + return rc; } if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') { - MS_LOG(ERROR) << "CSVDataset: The field delimiter should not be \", \\r, \\n"; - return false; + std::string err_msg = "CSVDataset: The field delimiter should not be \", \\r, \\n"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (num_samples_ < 0) { - MS_LOG(ERROR) << "CSVDataset: Invalid number of samples: " << num_samples_; - return false; + std::string err_msg = "CSVDataset: Invalid number of samples: " + num_samples_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetShardParams("CSVDataset", num_shards_, shard_id_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVDataset", num_shards_, shard_id_)); + if (rc.IsError()) { + return rc; } if (find(column_defaults_.begin(), column_defaults_.end(), nullptr) != column_defaults_.end()) { - MS_LOG(ERROR) << "CSVDataset: column_default should not be null."; - return false; + std::string err_msg = "CSVDataset: column_default should not be null."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (!column_names_.empty()) { - if (!ValidateDatasetColumnParam("CSVDataset", "column_names", column_names_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVDataset", "column_names", column_names_)); + if (rc.IsError()) { + return rc; } } - return true; + return Status::OK(); } // Function to build CSVDataset @@ -1286,9 +1381,20 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std class_indexing_(class_indexing), exts_(extensions) {} -bool ImageFolderDataset::ValidateParams() { - return ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_) && - ValidateDatasetSampler("ImageFolderDataset", sampler_); +Status ImageFolderDataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderDataset", sampler_)); + if (rc.IsError()) { + return rc; + } + + return Status::OK(); } std::vector> ImageFolderDataset::Build() { @@ -1315,33 +1421,39 @@ ManifestDataset::ManifestDataset(const std::string &dataset_file, const std::str const std::map &class_indexing, bool decode) : dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {} -bool ManifestDataset::ValidateParams() { +Status ManifestDataset::ValidateParams() { + Status rc; + std::vector forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; for (char c : dataset_file_) { auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); if (p != forbidden_symbols.end()) { - MS_LOG(ERROR) << "filename should not contains :*?\"<>|`&;\'"; - return false; + std::string err_msg = "filename should not contains :*?\"<>|`&;\'"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } } Path manifest_file(dataset_file_); if (!manifest_file.Exists()) { - MS_LOG(ERROR) << "dataset file: [" << dataset_file_ << "] is invalid or not exist"; - return false; + std::string err_msg = "dataset file: [" + dataset_file_ + "] is invalid or not exist"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetSampler("ManifestDataset", sampler_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestDataset", sampler_)); + if (rc.IsError()) { + return rc; } std::vector usage_list = {"train", "eval", "inference"}; if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { - MS_LOG(ERROR) << "usage should be train, eval or inference."; - return false; + std::string err_msg = "usage should be train, eval or inference."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } std::vector> ManifestDataset::Build() { @@ -1368,9 +1480,25 @@ std::vector> ManifestDataset::Build() { MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr sampler) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} -bool MnistDataset::ValidateParams() { - return ValidateStringValue(usage_, {"train", "test", "all"}) && - ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_); +Status MnistDataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistDataset", dataset_dir_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistDataset", sampler_)); + if (rc.IsError()) { + return rc; + } + + RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); + if (rc.IsError()) { + return rc; + } + + return Status::OK(); } std::vector> MnistDataset::Build() { @@ -1390,20 +1518,28 @@ std::vector> MnistDataset::Build() { } // ValideParams for RandomDataset -bool RandomDataset::ValidateParams() { +Status RandomDataset::ValidateParams() { + Status rc; + if (total_rows_ < 0) { - MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows_; - return false; + std::string err_msg = "RandomDataset: total_rows must be greater than or equal 0, now get " + total_rows_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetSampler("RandomDataset", sampler_)) { - return false; + + RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomDataset", sampler_)); + if (rc.IsError()) { + return rc; } + if (!columns_list_.empty()) { - if (!ValidateDatasetColumnParam("RandomDataset", "columns_list", columns_list_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomDataset", "columns_list", columns_list_)); + if (rc.IsError()) { + return rc; } } - return true; + + return Status::OK(); } int32_t RandomDataset::GenRandomInt(int32_t min, int32_t max) { @@ -1466,21 +1602,26 @@ TextFileDataset::TextFileDataset(std::vector dataset_files, int32_t num_shards_(num_shards), shard_id_(shard_id) {} -bool TextFileDataset::ValidateParams() { - if (!ValidateDatasetFilesParam("TextFileDataset", dataset_files_)) { - return false; +Status TextFileDataset::ValidateParams() { + Status rc; + + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_)); + if (rc.IsError()) { + return rc; } if (num_samples_ < 0) { - MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_; - return false; + std::string err_msg = "TextFileDataset: Invalid number of samples: " + num_samples_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_)); + if (rc.IsError()) { + return rc; } - return true; + return Status::OK(); } // Function to build TextFileDataset @@ -1526,7 +1667,7 @@ std::vector> TextFileDataset::Build() { #ifndef ENABLE_ANDROID // Validator for TFRecordDataset -bool TFRecordDataset::ValidateParams() { return true; } +Status TFRecordDataset::ValidateParams() { return Status::OK(); } // Function to build TFRecordDataset std::vector> TFRecordDataset::Build() { @@ -1586,36 +1727,47 @@ VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, decode_(decode), sampler_(sampler) {} -bool VOCDataset::ValidateParams() { +Status VOCDataset::ValidateParams() { + Status rc; + Path dir(dataset_dir_); if (!dir.IsDirectory()) { - MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; - return false; + std::string err_msg = "Invalid dataset path or no dataset path is specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetSampler("VOCDataset", sampler_)) { - return false; + + RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCDataset", sampler_)); + if (rc.IsError()) { + return rc; } + if (task_ == "Segmentation") { if (!class_index_.empty()) { - MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task."; - return false; + std::string err_msg = "class_indexing is invalid in Segmentation task."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt"; if (!imagesets_file.Exists()) { - MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; - return false; + std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; + MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } } else if (task_ == "Detection") { Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt"; if (!imagesets_file.Exists()) { - MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; - return false; + std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; + MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } } else { - MS_LOG(ERROR) << "Invalid task: " << task_; - return false; + std::string err_msg = "Invalid task: " + task_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + + return Status::OK(); } // Function to build VOCDataset @@ -1683,16 +1835,18 @@ std::vector> BatchDataset::Build() { return node_ops; } -bool BatchDataset::ValidateParams() { +Status BatchDataset::ValidateParams() { if (batch_size_ <= 0) { - MS_LOG(ERROR) << "Batch: batch_size should be positive integer, but got: " << batch_size_; - return false; + std::string err_msg = "Batch: batch_size should be positive integer, but got: " + batch_size_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (!cols_to_map_.empty()) { - MS_LOG(ERROR) << "cols_to_map functionality is not implemented in C++; this should be left empty."; - return false; + std::string err_msg = "cols_to_map functionality is not implemented in C++; this should be left empty."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } #ifndef ENABLE_ANDROID @@ -1725,48 +1879,55 @@ std::vector> BucketBatchByLengthDataset::Build() { return node_ops; } -bool BucketBatchByLengthDataset::ValidateParams() { +Status BucketBatchByLengthDataset::ValidateParams() { if (element_length_function_ == nullptr && column_names_.size() != 1) { - MS_LOG(ERROR) << "BucketBatchByLength: If element_length_function is not specified, exactly one column name " - "should be passed."; - return false; + std::string err_msg = + "BucketBatchByLength: element_length_function not specified, but not one column name: " + column_names_.size(); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } // Check bucket_boundaries: must be positive and strictly increasing if (bucket_boundaries_.empty()) { - MS_LOG(ERROR) << "BucketBatchByLength: bucket_boundaries cannot be empty."; - return false; + std::string err_msg = "BucketBatchByLength: bucket_boundaries cannot be empty."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } for (int i = 0; i < bucket_boundaries_.size(); i++) { if (bucket_boundaries_[i] <= 0) { + std::string err_msg = "BucketBatchByLength: Invalid non-positive bucket_boundaries, index: "; MS_LOG(ERROR) << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: " << i << " was: " << bucket_boundaries_[i]; - return false; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) { + std::string err_msg = "BucketBatchByLength: Invalid bucket_boundaries not be strictly increasing."; MS_LOG(ERROR) << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: " << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i] << " respectively."; - return false; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } } // Check bucket_batch_sizes: must be positive if (bucket_batch_sizes_.empty()) { - MS_LOG(ERROR) << "BucketBatchByLength: bucket_batch_sizes must be non-empty"; - return false; + std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must be non-empty"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) { - MS_LOG(ERROR) << "BucketBatchByLength: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; - return false; + std::string err_msg = "BucketBatchByLength: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) { - MS_LOG(ERROR) << "BucketBatchByLength: bucket_batch_sizes must only contain positive numbers."; - return false; + std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must only contain positive numbers."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } BuildVocabDataset::BuildVocabDataset(std::shared_ptr vocab, const std::vector &columns, @@ -1791,26 +1952,31 @@ std::vector> BuildVocabDataset::Build() { return node_ops; } -bool BuildVocabDataset::ValidateParams() { +Status BuildVocabDataset::ValidateParams() { + Status rc; if (vocab_ == nullptr) { - MS_LOG(ERROR) << "BuildVocab: vocab is null."; - return false; + std::string err_msg = "BuildVocab: vocab is null."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (top_k_ <= 0) { - MS_LOG(ERROR) << "BuildVocab: top_k shoule be positive, but got: " << top_k_; - return false; + std::string err_msg = "BuildVocab: top_k should be positive, but got: " + top_k_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) { - MS_LOG(ERROR) << "BuildVocab: requency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " + std::string err_msg = "BuildVocab: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; + MS_LOG(ERROR) << "BuildVocab: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " << "but got [" << freq_range_.first << ", " << freq_range_.second << "]"; - return false; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (!columns_.empty()) { - if (!ValidateDatasetColumnParam("BuildVocab", "columns", columns_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocab", "columns", columns_)); + if (rc.IsError()) { + return rc; } } - return true; + return Status::OK(); } #endif @@ -1819,16 +1985,18 @@ ConcatDataset::ConcatDataset(const std::vector> &datase this->children = datasets_; } -bool ConcatDataset::ValidateParams() { +Status ConcatDataset::ValidateParams() { if (datasets_.empty()) { - MS_LOG(ERROR) << "Concat: concatenated datasets are not specified."; - return false; + std::string err_msg = "Concat: concatenated datasets are not specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { - MS_LOG(ERROR) << "Concat: concatenated dataset should not be null."; - return false; + std::string err_msg = "Concat: concatenated datasets should not be null."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } std::vector> ConcatDataset::Build() { @@ -1870,41 +2038,54 @@ std::vector> MapDataset::Build() { return node_ops; } -bool MapDataset::ValidateParams() { +Status MapDataset::ValidateParams() { + Status rc; + if (operations_.empty()) { - MS_LOG(ERROR) << "Map: No operation is specified."; - return false; + std::string err_msg = "Map: No operation is specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (!input_columns_.empty()) { - if (!ValidateDatasetColumnParam("MapDataset", "input_columns", input_columns_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "input_columns", input_columns_)); + if (rc.IsError()) { + return rc; } } if (!output_columns_.empty()) { - if (!ValidateDatasetColumnParam("MapDataset", "output_columns", output_columns_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "output_columns", output_columns_)); + if (rc.IsError()) { + return rc; } } if (!project_columns_.empty()) { - if (!ValidateDatasetColumnParam("MapDataset", "project_columns", project_columns_)) { - return false; + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "project_columns", project_columns_)); + if (rc.IsError()) { + return rc; } } - return true; + return Status::OK(); } // Function to build ProjectOp ProjectDataset::ProjectDataset(const std::vector &columns) : columns_(columns) {} -bool ProjectDataset::ValidateParams() { +Status ProjectDataset::ValidateParams() { + Status rc; + if (columns_.empty()) { - MS_LOG(ERROR) << "ProjectDataset: No columns are specified."; - return false; + std::string err_msg = "ProjectDataset: No columns are specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetColumnParam("ProjectDataset", "columns", columns_)) { - return false; + + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectDataset", "columns", columns_)); + if (rc.IsError()) { + return rc; } - return true; + + return Status::OK(); } std::vector> ProjectDataset::Build() { @@ -1920,16 +2101,26 @@ RenameDataset::RenameDataset(const std::vector &input_columns, const std::vector &output_columns) : input_columns_(input_columns), output_columns_(output_columns) {} -bool RenameDataset::ValidateParams() { +Status RenameDataset::ValidateParams() { + Status rc; + if (input_columns_.size() != output_columns_.size()) { - MS_LOG(ERROR) << "RenameDataset: input and output columns must be the same size"; - return false; + std::string err_msg = "RenameDataset: input and output columns must be the same size"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - if (!ValidateDatasetColumnParam("RenameDataset", "input_columns", input_columns_) || - !ValidateDatasetColumnParam("RenameDataset", "output_columns", output_columns_)) { - return false; + + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameDataset", "input_columns", input_columns_)); + if (rc.IsError()) { + return rc; } - return true; + + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameDataset", "output_columns", output_columns_)); + if (rc.IsError()) { + return rc; + } + + return Status::OK(); } std::vector> RenameDataset::Build() { @@ -1950,13 +2141,15 @@ std::vector> RepeatDataset::Build() { return node_ops; } -bool RepeatDataset::ValidateParams() { +Status RepeatDataset::ValidateParams() { if (repeat_count_ <= 0 && repeat_count_ != -1) { - MS_LOG(ERROR) << "Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " << repeat_count_; - return false; + std::string err_msg = + "Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + repeat_count_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } // Constructor for ShuffleDataset @@ -1974,13 +2167,14 @@ std::vector> ShuffleDataset::Build() { } // Function to validate the parameters for ShuffleDataset -bool ShuffleDataset::ValidateParams() { +Status ShuffleDataset::ValidateParams() { if (shuffle_size_ <= 1) { - MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_; - return false; + std::string err_msg = "ShuffleDataset: Invalid input, shuffle_size: " + shuffle_size_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } // Constructor for SkipDataset @@ -1996,13 +2190,13 @@ std::vector> SkipDataset::Build() { } // Function to validate the parameters for SkipDataset -bool SkipDataset::ValidateParams() { +Status SkipDataset::ValidateParams() { if (skip_count_ <= -1) { - MS_LOG(ERROR) << "Skip: skip_count should not be negative, skip_count: " << skip_count_; - return false; + std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + skip_count_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - - return true; + return Status::OK(); } // Constructor for TakeDataset @@ -2018,13 +2212,13 @@ std::vector> TakeDataset::Build() { } // Function to validate the parameters for TakeDataset -bool TakeDataset::ValidateParams() { +Status TakeDataset::ValidateParams() { if (take_count_ <= 0 && take_count_ != -1) { - MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_; - return false; + std::string err_msg = "Take: take_count should be either -1 or positive integer, take_count: " + take_count_; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - - return true; + return Status::OK(); } // Function to build ZipOp @@ -2034,16 +2228,18 @@ ZipDataset::ZipDataset(const std::vector> &datasets) : } } -bool ZipDataset::ValidateParams() { +Status ZipDataset::ValidateParams() { if (datasets_.empty()) { - MS_LOG(ERROR) << "Zip: dataset to zip are not specified."; - return false; + std::string err_msg = "Zip: datasets to zip are not specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { - MS_LOG(ERROR) << "ZipDataset: zip dataset should not be null."; - return false; + std::string err_msg = "ZipDataset: zip datasets should not be null."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); } - return true; + return Status::OK(); } std::vector> ZipDataset::Build() { diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 62bf9668d6..80c8f076b6 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -18,17 +18,17 @@ #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ #include -#include +#include #include #include -#include -#include #include +#include +#include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/data_schema.h" -#include "minddata/dataset/include/tensor.h" #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/samplers.h" +#include "minddata/dataset/include/tensor.h" #include "minddata/dataset/include/type_id.h" #include "minddata/dataset/kernels/c_func_op.h" #include "minddata/dataset/kernels/tensor_op.h" @@ -442,8 +442,8 @@ class Dataset : public std::enable_shared_from_this { virtual std::vector> Build() = 0; /// \brief Pure virtual function for derived class to implement parameters validation - /// \return bool true if all the parameters are valid - virtual bool ValidateParams() = 0; + /// \return Status Status::OK() if all the parameters are valid + virtual Status ValidateParams() = 0; /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator @@ -692,8 +692,8 @@ class AlbumDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -717,8 +717,8 @@ class CelebADataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -743,8 +743,8 @@ class Cifar10Dataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -765,8 +765,8 @@ class Cifar100Dataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -790,8 +790,8 @@ class CLUEDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: /// \brief Split string based on a character delimiter @@ -821,8 +821,8 @@ class CocoDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -869,8 +869,8 @@ class CSVDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector dataset_files_; @@ -899,8 +899,8 @@ class ImageFolderDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -926,8 +926,8 @@ class ManifestDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_file_; @@ -951,8 +951,8 @@ class MnistDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::string dataset_dir_; @@ -989,8 +989,8 @@ class RandomDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: /// \brief A quick inline for producing a random number between (and including) min/max @@ -1023,8 +1023,8 @@ class TextFileDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector dataset_files_; @@ -1074,8 +1074,8 @@ class TFRecordDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector dataset_files_; @@ -1104,8 +1104,8 @@ class VOCDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: const std::string kColumnImage = "image"; @@ -1140,8 +1140,8 @@ class BatchDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: int32_t batch_size_; @@ -1170,8 +1170,8 @@ class BucketBatchByLengthDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector column_names_; @@ -1198,8 +1198,8 @@ class BuildVocabDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::shared_ptr vocab_; @@ -1224,8 +1224,8 @@ class ConcatDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector> datasets_; @@ -1245,8 +1245,8 @@ class MapDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector> operations_; @@ -1268,8 +1268,8 @@ class ProjectDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector columns_; @@ -1288,8 +1288,8 @@ class RenameDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector input_columns_; @@ -1309,8 +1309,8 @@ class RepeatDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: int32_t repeat_count_; @@ -1324,7 +1324,7 @@ class ShuffleDataset : public Dataset { std::vector> Build() override; - bool ValidateParams() override; + Status ValidateParams() override; private: int32_t shuffle_size_; @@ -1345,8 +1345,8 @@ class SkipDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: int32_t skip_count_; @@ -1365,8 +1365,8 @@ class TakeDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: int32_t take_count_; @@ -1385,8 +1385,8 @@ class ZipDataset : public Dataset { std::vector> Build() override; /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; private: std::vector> datasets_; diff --git a/mindspore/ccsrc/minddata/dataset/include/status.h b/mindspore/ccsrc/minddata/dataset/include/status.h index 9f45e12c68..bc63de9870 100644 --- a/mindspore/ccsrc/minddata/dataset/include/status.h +++ b/mindspore/ccsrc/minddata/dataset/include/status.h @@ -66,6 +66,11 @@ namespace dataset { } \ } while (false) +#define RETURN_STATUS_SYNTAX_ERROR(_e) \ + do { \ + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ + } while (false) + enum class StatusCode : char { kOK = 0, kOutOfMemory = 1, diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h index 9f45e12c68..bc63de9870 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.h +++ b/mindspore/ccsrc/minddata/dataset/util/status.h @@ -66,6 +66,11 @@ namespace dataset { } \ } while (false) +#define RETURN_STATUS_SYNTAX_ERROR(_e) \ + do { \ + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ + } while (false) + enum class StatusCode : char { kOK = 0, kOutOfMemory = 1,