diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 804613e40a..27854bdf07 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, } MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { - if (column_schema_id_.find(category_field) == column_schema_id_.end()) { - MS_LOG(ERROR) << "Field " << category_field << " does not exist."; + std::map index_columns; + for (auto &field : get_shard_header()->get_fields()) { + index_columns[field.second] = field.first; + } + if (index_columns.find(category_field) == index_columns.end()) { + MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; return FAILED; } - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); if (SUCCESS != ret.first) { return FAILED; } diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 3f87127e26..06b740bb6b 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset): if block_reader is True and sampler is not None: raise ValueError("block reader not allowed true when use sampler") - if shuffle is True and sampler is not None: - raise ValueError("shuffle not allowed true when use sampler") + if shuffle is not None and sampler is not None: + raise ValueError("shuffle not allowed when use sampler") if block_reader is False and sampler is None: self.global_shuffle = not bool(shuffle is False) diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index e1d54fa7c8..2a269ffc80 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column(): os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) +def test_cv_minddataset_pk_sample_exclusive_shuffle(): + create_cv_mindrecord(1) + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(2) + with pytest.raises(Exception, match="shuffle not allowed when use sampler"): + data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, + sampler=sampler, shuffle=False) + num_iter = 0 + for item in data_set.create_dict_iterator(): + num_iter += 1 + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 584bb88041..5656a08ae4 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -60,7 +60,21 @@ def add_and_remove_cv_file(): os.remove("{}".format(x)) os.remove("{}.db".format(x)) +def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + num_readers = 4 + sampler = ds.PKSampler(2) + data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"]