|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""test mindrecord base"""
|
|
|
|
|
import numpy as np
|
|
|
|
|
import os
|
|
|
|
|
import uuid
|
|
|
|
|
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
|
|
|
|
@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
|
|
|
|
|
CV3_FILE_NAME = "./imagenet_append.mindrecord"
|
|
|
|
|
NLP_FILE_NAME = "./aclImdb.mindrecord"
|
|
|
|
|
|
|
|
|
|
def test_write_read_process():
    """Write six records with FileWriter and check FileReader returns them unchanged.

    The schema mixes string, int32, float64, variable-length int64,
    fixed-shape float32 and bytes fields. Records are compared field by
    field, in the order they were written.

    The reader is closed and the generated files are removed in ``finally``
    blocks, so a failing assertion does not leave ``test.mindrecord`` behind
    for other tests that reuse the same file name.
    """
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
             "data": bytes("image bytes def", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
             "data": bytes("image bytes ghi", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
             "data": bytes("image bytes jkl", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
             "data": bytes("image bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
             "data": bytes("image bytes pqr", encoding='UTF-8')}
            ]
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int64", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}

    try:
        writer = FileWriter(mindrecord_file_name)
        writer.add_schema(schema, "data is so cool")
        writer.write_raw_data(data)
        writer.commit()

        reader = FileReader(mindrecord_file_name)
        try:
            count = 0
            for index, x in enumerate(reader.get_next()):
                assert len(x) == 6
                for field in x:
                    expected = data[index][field]
                    # ndarray equality is element-wise, so reduce it with .all().
                    if isinstance(x[field], np.ndarray):
                        assert (x[field] == expected).all()
                    else:
                        assert x[field] == expected
                count += 1
                logger.info("#item{}: {}".format(index, x))
            assert count == len(data)
        finally:
            reader.close()
    finally:
        # Remove the record file and its companion ".db" file; either may be
        # missing if the writer failed early, so check before deleting.
        for path in (mindrecord_file_name, "{}.db".format(mindrecord_file_name)):
            if os.path.exists(path):
                os.remove(path)
|
|
|
|
|
|
|
|
|
|
def test_write_read_process_with_define_index_field():
    """Write six records with ``label`` added as an index field and read them back.

    The flow is schema, ``add_index(["label"])``, write, commit, followed by
    a field-by-field comparison of every record returned by FileReader
    against the data that was written, in order.

    The reader is closed and the generated files are removed in ``finally``
    blocks, so a failing assertion does not leave ``test.mindrecord`` behind
    for other tests that reuse the same file name.
    """
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
             "data": bytes("image bytes def", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
             "data": bytes("image bytes ghi", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
             "data": bytes("image bytes jkl", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
             "data": bytes("image bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
             "data": bytes("image bytes pqr", encoding='UTF-8')}
            ]
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int64", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}

    try:
        writer = FileWriter(mindrecord_file_name)
        writer.add_schema(schema, "data is so cool")
        # The index is declared after the schema and before any data is written.
        writer.add_index(["label"])
        writer.write_raw_data(data)
        writer.commit()

        reader = FileReader(mindrecord_file_name)
        try:
            count = 0
            for index, x in enumerate(reader.get_next()):
                assert len(x) == 6
                for field in x:
                    expected = data[index][field]
                    # ndarray equality is element-wise, so reduce it with .all().
                    if isinstance(x[field], np.ndarray):
                        assert (x[field] == expected).all()
                    else:
                        assert x[field] == expected
                count += 1
                logger.info("#item{}: {}".format(index, x))
            assert count == len(data)
        finally:
            reader.close()
    finally:
        # Remove the record file and its companion ".db" file; either may be
        # missing if the writer failed early, so check before deleting.
        for path in (mindrecord_file_name, "{}.db".format(mindrecord_file_name)):
            if os.path.exists(path):
                os.remove(path)
|
|
|
|
|
|
|
|
|
|
def test_cv_file_writer_tutorial():
|
|
|
|
|
"""tutorial for cv dataset writer."""
|
|
|
|
|
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
|
|
|
|
@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
|
|
|
|
|
assert len(row1[0]) == 3
|
|
|
|
|
assert row1[0]['label'] == 822
|
|
|
|
|
|
|
|
|
|
def test_cv_page_reader_tutorial_by_file_name():
    """Page through ``CV_FILE_NAME + "0"`` with ``file_name`` as the category field.

    Uses the method-style MindPage API: ``get_category_fields`` and
    ``set_category_field``. One row is then fetched by category id and one
    by category name, and each row's size and label are checked.
    """
    page_reader = MindPage(CV_FILE_NAME + "0")

    candidate_fields = page_reader.get_category_fields()
    assert candidate_fields == ['file_name', 'label'], \
        'failed on getting candidate category fields.'

    status = page_reader.set_category_field("file_name")
    assert status == SUCCESS, 'failed on setting category field.'

    category_info = page_reader.read_category_info()
    logger.info("category info: {}".format(category_info))

    # Lookup by category id: (category id 0, page 0, one row per page).
    rows_by_id = page_reader.read_at_page_by_id(0, 0, 1)
    assert len(rows_by_id) == 1
    first_by_id = rows_by_id[0]
    assert len(first_by_id) == 3
    assert first_by_id['label'] == 490

    # Lookup by category name, using the same paging arguments.
    rows_by_name = page_reader.read_at_page_by_name("image_00007.jpg", 0, 1)
    assert len(rows_by_name) == 1
    first_by_name = rows_by_name[0]
    assert len(first_by_name) == 3
    assert first_by_name['label'] == 13
|
|
|
|
|
|
|
|
|
|
def test_cv_page_reader_tutorial_new_api():
|
|
|
|
|
"""tutorial for cv page reader."""
|
|
|
|
|
reader = MindPage(CV_FILE_NAME + "0")
|
|
|
|
|
fields = reader.candidate_fields
|
|
|
|
|
assert fields == ['file_name', 'label'],\
|
|
|
|
|
'failed on getting candidate category fields.'
|
|
|
|
|
|
|
|
|
|
reader.category_field = "file_name"
|
|
|
|
|
|
|
|
|
|
info = reader.read_category_info()
|
|
|
|
|
logger.info("category info: {}".format(info))
|
|
|
|
|
|
|
|
|
|
row = reader.read_at_page_by_id(0, 0, 1)
|
|
|
|
|
assert len(row) == 1
|
|
|
|
|
assert len(row[0]) == 3
|
|
|
|
|
assert row[0]['label'] == 490
|
|
|
|
|
|
|
|
|
|
row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
|
|
|
|
|
assert len(row1) == 1
|
|
|
|
|
assert len(row1[0]) == 3
|
|
|
|
|
assert row1[0]['label'] == 13
|
|
|
|
|
|
|
|
|
|
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
|
|
|
|
|
for x in range(FILES_NUM)]
|
|
|
|
|
for x in paths:
|
|
|
|
|