diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 669f0397d8..d008912738 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -14,10 +14,10 @@ * limitations under the License. */ +#include "minddata/dataset/include/datasets.h" #include #include #include -#include "minddata/dataset/include/datasets.h" #include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/transforms.h" // Source dataset headers (in alphabetical order) @@ -696,7 +696,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data return Status::OK(); } -// Helper function to validate dataset dataset files parameter +// Helper function to validate dataset files parameter Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector &dataset_files) { if (dataset_files.empty()) { std::string err_msg = dataset_name + ": dataset_files is not specified."; @@ -743,7 +743,6 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s // Helper function to validate dataset sampler parameter Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr &sampler) { if (sampler == nullptr) { - MS_LOG(ERROR) << dataset_name << ": Sampler is not constructed correctly, sampler: nullptr"; std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); @@ -751,12 +750,13 @@ Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared return Status::OK(); } -Status ValidateStringValue(const std::string &str, const std::unordered_set &valid_strings) { +Status ValidateStringValue(const std::string &dataset_name, const std::string &str, + const std::unordered_set &valid_strings) { if (valid_strings.find(str) == valid_strings.end()) { std::string mode; mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); - std::string err_msg = str + " does not match any mode in [" + mode + " ]"; + std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -842,7 +842,7 @@ Status CelebANode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"})); + RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"})); return Status::OK(); } @@ -873,7 +873,7 @@ Status Cifar10Node::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"})); return Status::OK(); } @@ -906,7 +906,7 @@ Status Cifar100Node::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"})); return Status::OK(); } @@ -945,20 +945,9 @@ CLUENode::CLUENode(const std::vector clue_files, std::string task, Status CLUENode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); - std::vector task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; - std::vector usage_list = {"train", "test", "eval"}; + RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"})); - if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) { - std::string err_msg = "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { - std::string err_msg = "usage should be train, test or eval."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } + RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"})); if (num_samples_ < 0) { std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_); @@ -1133,18 +1122,12 @@ Status CocoNode::ValidateParams() { Path annotation_file(annotation_file_); if (!annotation_file.Exists()) { - std::string err_msg = "annotation_file is invalid or not exist"; + std::string err_msg = "CocoNode: annotation_file is invalid or does not exist."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - std::set task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"}; - auto task_iter = task_list.find(task_); - if (task_iter == task_list.end()) { - std::string err_msg = "Invalid task type"; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } + RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"})); return Status::OK(); } @@ -1348,7 +1331,7 @@ Status ManifestNode::ValidateParams() { for (char c : dataset_file_) { auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); if (p != forbidden_symbols.end()) { - std::string err_msg = "filename should not contains :*?\"<>|`&;\'"; + std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -1356,19 +1339,14 @@ Status ManifestNode::ValidateParams() { Path manifest_file(dataset_file_); if (!manifest_file.Exists()) { - std::string err_msg = "dataset file: [" + dataset_file_ + "] is invalid or not exist"; + std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_)); - std::vector usage_list = {"train", "eval", "inference"}; - if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { - std::string err_msg = "usage should be train, eval or inference."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } + RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"})); return Status::OK(); } @@ -1536,7 +1514,7 @@ Status MnistNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"})); return Status::OK(); } @@ -1753,35 +1731,32 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const Status VOCNode::ValidateParams() { Path dir(dataset_dir_); - if (!dir.IsDirectory()) { - std::string err_msg = "Invalid dataset path or no dataset path is specified."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } + + RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_)); if (task_ == "Segmentation") { if (!class_index_.empty()) { - std::string err_msg = "class_indexing is invalid in Segmentation task."; + std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt"; if (!imagesets_file.Exists()) { - std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; - MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; + std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist"; + MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; RETURN_STATUS_SYNTAX_ERROR(err_msg); } } else if (task_ == "Detection") { Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt"; if (!imagesets_file.Exists()) { - std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist"; - MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; + std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist"; + MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!"; RETURN_STATUS_SYNTAX_ERROR(err_msg); } } else { - std::string err_msg = "Invalid task: " + task_; + std::string err_msg = "VOCNode: Invalid task: " + task_; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -1859,15 +1834,17 @@ std::vector> BatchNode::Build() { Status BatchNode::ValidateParams() { if (batch_size_ <= 0) { - std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_); + std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (!cols_to_map_.empty()) { - std::string err_msg = "cols_to_map functionality is not implemented in C++; this should be left empty."; + std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + return Status::OK(); } @@ -1906,28 +1883,29 @@ std::vector> BucketBatchByLengthNode::Build() { Status BucketBatchByLengthNode::ValidateParams() { if (element_length_function_ == nullptr && column_names_.size() != 1) { - std::string err_msg = - "BucketBatchByLength: element_length_function not specified, but not one column name: " + column_names_.size(); + std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " + + column_names_.size(); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } // Check bucket_boundaries: must be positive and strictly increasing if (bucket_boundaries_.empty()) { - std::string err_msg = "BucketBatchByLength: bucket_boundaries cannot be empty."; + std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + for (int i = 0; i < bucket_boundaries_.size(); i++) { if (bucket_boundaries_[i] <= 0) { - std::string err_msg = "BucketBatchByLength: Invalid non-positive bucket_boundaries, index: "; + std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: "; MS_LOG(ERROR) << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: " << i << " was: " << bucket_boundaries_[i]; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) { - std::string err_msg = "BucketBatchByLength: Invalid bucket_boundaries not be strictly increasing."; + std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing."; MS_LOG(ERROR) << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: " << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i] @@ -1938,20 +1916,24 @@ Status BucketBatchByLengthNode::ValidateParams() { // Check bucket_batch_sizes: must be positive if (bucket_batch_sizes_.empty()) { - std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must be non-empty"; + std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) { - std::string err_msg = "BucketBatchByLength: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; + std::string err_msg = + "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) { - std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must only contain positive numbers."; + std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + return Status::OK(); } @@ -1981,26 +1963,26 @@ std::vector> BuildVocabNode::Build() { Status BuildVocabNode::ValidateParams() { if (vocab_ == nullptr) { - std::string err_msg = "BuildVocab: vocab is null."; + std::string err_msg = "BuildVocabNode: vocab is null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (top_k_ <= 0) { - std::string err_msg = "BuildVocab: top_k should be positive, but got: " + top_k_; + std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) { - std::string err_msg = "BuildVocab: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; - MS_LOG(ERROR) << "BuildVocab: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " + std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)"; + MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), " << "but got [" << freq_range_.first << ", " << freq_range_.second << "]"; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (!columns_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocab", "columns", columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_)); } return Status::OK(); @@ -2014,15 +1996,17 @@ ConcatNode::ConcatNode(const std::vector> &datasets) : Status ConcatNode::ValidateParams() { if (datasets_.empty()) { - std::string err_msg = "Concat: concatenated datasets are not specified."; + std::string err_msg = "ConcatNode: concatenated datasets are not specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { - std::string err_msg = "Concat: concatenated datasets should not be null."; + std::string err_msg = "ConcatNode: concatenated datasets should not be null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + return Status::OK(); } @@ -2070,7 +2054,7 @@ std::vector> MapNode::Build() { Status MapNode::ValidateParams() { if (operations_.empty()) { - std::string err_msg = "Map: No operation is specified."; + std::string err_msg = "MapNode: No operation is specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -2158,8 +2142,8 @@ std::vector> RepeatNode::Build() { Status RepeatNode::ValidateParams() { if (repeat_count_ <= 0 && repeat_count_ != -1) { - std::string err_msg = - "Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_); + std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " + + std::to_string(repeat_count_); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -2211,10 +2195,11 @@ std::vector> SkipNode::Build() { // Function to validate the parameters for SkipNode Status SkipNode::ValidateParams() { if (skip_count_ <= -1) { - std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); + std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + return Status::OK(); } @@ -2236,7 +2221,7 @@ std::vector> TakeNode::Build() { Status TakeNode::ValidateParams() { if (take_count_ <= 0 && take_count_ != -1) { std::string err_msg = - "Take: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); + "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -2252,15 +2237,17 @@ ZipNode::ZipNode(const std::vector> &datasets) : datase Status ZipNode::ValidateParams() { if (datasets_.empty()) { - std::string err_msg = "Zip: datasets to zip are not specified."; + std::string err_msg = "ZipNode: datasets to zip are not specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { std::string err_msg = "ZipNode: zip datasets should not be null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } + return Status::OK(); }