diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index ea2e8352da..dedee8e9b3 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) { .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); (void)py::class_>( *m, "MindrecordPkSampler") - .def(py::init([](int64_t kVal, bool shuffle) { + .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { if (shuffle == true) { - return std::make_shared("label", kVal, std::numeric_limits::max(), + return std::make_shared(kColumn, kVal, std::numeric_limits::max(), GetSeed()); } else { - return std::make_shared("label", kVal); + return std::make_shared(kColumn, kVal); } })); diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index dd34615f7e..4cbb2b3767 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, } MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { + if (column_schema_id_.find(category_field) == column_schema_id_.end()) { + MS_LOG(ERROR) << "Field " << category_field << " does not exist."; + return FAILED; + } auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); if (SUCCESS != ret.first) { return FAILED; @@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri for (auto &field : index_fields) { map_schema_id_fields[field.second] = field.first; } + + if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) { + MS_LOG(ERROR) << "Field " << category_field << " does not exist."; + return -1; + } auto ret = 
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); if (SUCCESS != ret.first) { diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index 80816e7a79..2a9c2c0966 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { if (dataset_size == 0) return dataset_size; - if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) { + if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) { return std::min(num_categories_, num_classes) * num_elements_; } return -1; diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 82759989cb..ce732d28a7 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler): num_val (int): Number of elements to sample for each class. num_class (int, optional): Number of classes to sample (default=None, all classes). shuffle (bool, optional): If true, the class IDs are shuffled (default=False). + class_column (str, optional): Name of the column used to classify the dataset (default='label'); only used for MindDataset. Examples: >>> import mindspore.dataset as ds @@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler): ValueError: If shuffle is not boolean. 
""" - def __init__(self, num_val, num_class=None, shuffle=False): + def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'): if num_val <= 0: raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) @@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler): self.num_val = num_val self.shuffle = shuffle + self.class_column = class_column # work for minddataset def create(self): return cde.PKSampler(self.num_val, self.shuffle) def _create_for_minddataset(self): - return cde.MindrecordPkSampler(self.num_val, self.shuffle) + if not self.class_column or not isinstance(self.class_column, str): + raise ValueError("class_column should be a not empty string value, " + "but got class_column={}".format(self.class_column)) + return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) class RandomSampler(BuiltinSampler): """ diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index 70add46b68..e1d54fa7c8 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -82,3 +82,18 @@ def test_minddataset_lack_db(): num_iter += 1 assert num_iter == 0 os.remove(CV_FILE_NAME) + + +def test_cv_minddataset_pk_sample_error_class_column(): + create_cv_mindrecord(1) + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(5, None, True, 'no_exsit_column') + with pytest.raises(Exception, match="MindRecordOp launch failed"): + data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler) + num_iter = 0 + for item in data_set.create_dict_iterator(): + num_iter += 1 + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) +