|
|
@ -16,6 +16,7 @@
|
|
|
|
This is the test module for saveOp.
|
|
|
|
This is the test module for saveOp.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
from string import punctuation
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
from mindspore import log as logger
|
|
|
|
from mindspore import log as logger
|
|
|
|
from mindspore.mindrecord import FileWriter
|
|
|
|
from mindspore.mindrecord import FileWriter
|
|
|
@ -24,7 +25,7 @@ import pytest
|
|
|
|
|
|
|
|
|
|
|
|
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
|
|
|
|
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
|
|
|
|
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
|
|
|
|
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
|
|
|
|
|
|
|
|
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
|
|
|
|
FILES_NUM = 1
|
|
|
|
FILES_NUM = 1
|
|
|
|
num_readers = 1
|
|
|
|
num_readers = 1
|
|
|
|
|
|
|
|
|
|
|
@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
|
|
|
|
|
|
|
|
|
|
|
|
with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
|
|
|
|
with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
|
|
|
|
d1.save(CV_FILE_NAME2, 1, "tfrecord")
|
|
|
|
d1.save(CV_FILE_NAME2, 1, "tfrecord")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cast_name(key):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Cast schema names which containing special characters to valid names.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
special_symbols = set('{}{}'.format(punctuation, ' '))
|
|
|
|
|
|
|
|
special_symbols.remove('_')
|
|
|
|
|
|
|
|
new_key = ['_' if x in special_symbols else x for x in key]
|
|
|
|
|
|
|
|
casted_key = ''.join(new_key)
|
|
|
|
|
|
|
|
return casted_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_07():
|
|
|
|
|
|
|
|
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
|
|
|
|
|
|
|
os.remove("{}".format(CV_FILE_NAME2))
|
|
|
|
|
|
|
|
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
|
|
|
|
|
|
|
os.remove("{}.db".format(CV_FILE_NAME2))
|
|
|
|
|
|
|
|
d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
|
|
|
|
|
|
|
|
tf_data = []
|
|
|
|
|
|
|
|
for x in d1.create_dict_iterator():
|
|
|
|
|
|
|
|
tf_data.append(x)
|
|
|
|
|
|
|
|
d1.save(CV_FILE_NAME2, FILES_NUM)
|
|
|
|
|
|
|
|
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
|
|
|
|
|
|
|
|
num_parallel_workers=num_readers,
|
|
|
|
|
|
|
|
shuffle=False)
|
|
|
|
|
|
|
|
mr_data = []
|
|
|
|
|
|
|
|
for x in d2.create_dict_iterator():
|
|
|
|
|
|
|
|
mr_data.append(x)
|
|
|
|
|
|
|
|
count = 0
|
|
|
|
|
|
|
|
for x in tf_data:
|
|
|
|
|
|
|
|
for k, v in x.items():
|
|
|
|
|
|
|
|
if isinstance(v, np.ndarray):
|
|
|
|
|
|
|
|
assert (v == mr_data[count][cast_name(k)]).all()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
assert v == mr_data[count][cast_name(k)]
|
|
|
|
|
|
|
|
count += 1
|
|
|
|
|
|
|
|
assert count == 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if os.path.exists("{}".format(CV_FILE_NAME2)):
|
|
|
|
|
|
|
|
os.remove("{}".format(CV_FILE_NAME2))
|
|
|
|
|
|
|
|
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
|
|
|
|
|
|
|
|
os.remove("{}.db".format(CV_FILE_NAME2))
|
|
|
|