diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 791de6c60b..32825fd9df 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -785,6 +785,8 @@ vector ShardReader::GetAllColumns() { MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, const std::vector> &operators) { + vector columns = GetAllColumns(); + CheckIfColumnInIndex(columns); for (const auto &rg : row_group_summary) { auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index d1471f47cb..4056825ff3 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -143,6 +143,7 @@ class FileWriter: ParamTypeError: If index field is invalid. MRMDefineIndexError: If index field is not primitive type. MRMAddIndexError: If failed to add index field. + MRMGetMetaError: If the schema is not set or get meta failed. 
""" if not index_fields or not isinstance(index_fields, list): raise ParamTypeError('index_fields', 'list') diff --git a/mindspore/mindrecord/tools/cifar100_to_mr.py b/mindspore/mindrecord/tools/cifar100_to_mr.py index a359de853d..c011c8f4b0 100644 --- a/mindspore/mindrecord/tools/cifar100_to_mr.py +++ b/mindspore/mindrecord/tools/cifar100_to_mr.py @@ -24,7 +24,7 @@ from mindspore import log as logger from .cifar100 import Cifar100 from ..common.exceptions import PathNotExistsError from ..filewriter import FileWriter -from ..shardutils import check_filename +from ..shardutils import check_filename, SUCCESS, FAILED try: cv2 = import_module("cv2") except ModuleNotFoundError: @@ -98,8 +98,11 @@ class Cifar100ToMR: data_list = _construct_raw_data(images, fine_labels, coarse_labels) test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels) - _generate_mindrecord(self.destination, data_list, fields, "img_train") - _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") + if _generate_mindrecord(self.destination, data_list, fields, "img_train") != SUCCESS: + return FAILED + if _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") != SUCCESS: + return FAILED + return SUCCESS def _construct_raw_data(images, fine_labels, coarse_labels): """ diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index da22f5c3b7..460a728b5c 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -47,7 +47,9 @@ def add_and_remove_cv_file(): os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None writer = FileWriter(CV_FILE_NAME, FILES_NUM) data = get_data(CV_DIR_NAME) - cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, "data": {"type": "bytes"}} 
writer.add_schema(cv_schema_json, "img_schema") writer.add_index(["file_name", "label"]) @@ -226,6 +228,24 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): num_iter += 1 assert num_iter == 20 +def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file): + """tutorial for cv minddataset.""" + columns_list = ["id", "data", "label"] + num_readers = 4 + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, + block_reader=True) + assert data_set.get_dataset_size() == 10 + repeat_num = 2 + data_set = data_set.repeat(repeat_num) + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) + logger.info("-------------- item[id]: {} ----------------------------".format(item["id"])) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + num_iter += 1 + assert num_iter == 20 + def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): """tutorial for cv minderdataset.""" @@ -359,13 +379,14 @@ def get_data(dir_name): lines = file_reader.readlines() data_list = [] - for line in lines: + for i, line in enumerate(lines): try: filename, label = line.split(",") label = label.strip("\n") with open(os.path.join(img_dir, filename), "rb") as file_reader: img = file_reader.read() - data_json = {"file_name": filename, + data_json = {"id": i, + "file_name": filename, "data": img, "label": int(label)} data_list.append(data_json) diff --git a/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py b/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py index b3a8d94589..e95f25aae4 100644 --- a/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py @@ -18,6 +18,7 @@ import pytest from 
mindspore.mindrecord import Cifar100ToMR from mindspore.mindrecord import FileReader from mindspore.mindrecord import MRMOpenError +from mindspore.mindrecord import SUCCESS from mindspore import log as logger CIFAR100_DIR = "../data/mindrecord/testCifar100Data" @@ -26,7 +27,8 @@ MINDRECORD_FILE = "./cifar100.mindrecord" def test_cifar100_to_mindrecord_without_index_fields(): """test transform cifar100 dataset to mindrecord without index fields.""" cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) - cifar100_transformer.transform() + ret = cifar100_transformer.transform() + assert ret == SUCCESS, "Failed to tranform from cifar100 to mindrecord" assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE + "_test") read() diff --git a/tests/ut/python/mindrecord/test_mindrecord_exception.py b/tests/ut/python/mindrecord/test_mindrecord_exception.py index 0a51fbf4e7..1f7a3f859d 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_exception.py +++ b/tests/ut/python/mindrecord/test_mindrecord_exception.py @@ -16,7 +16,7 @@ import os import pytest from mindspore.mindrecord import FileWriter, FileReader, MindPage -from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError +from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError from mindspore import log as logger from utils import get_data @@ -280,3 +280,9 @@ def test_cv_file_writer_shard_num_greater_than_1000(): with pytest.raises(ParamValueError) as err: FileWriter(CV_FILE_NAME, 1001) assert 'Shard number should between' in str(err.value) + +def test_add_index_without_add_schema(): + with pytest.raises(MRMGetMetaError) as err: + fw = FileWriter(CV_FILE_NAME) + fw.add_index(["label"]) + assert 'Failed to get meta info' in str(err.value)