|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
import collections
|
|
|
|
|
from importlib import import_module
|
|
|
|
|
import os
|
|
|
|
|
from string import punctuation
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
@ -35,6 +36,27 @@ TFRECORD_FILE_NAME = "test.tfrecord"
|
|
|
|
|
MINDRECORD_FILE_NAME = "test.mindrecord"
|
|
|
|
|
PARTITION_NUM = 1
|
|
|
|
|
|
|
|
|
|
def cast_name(key):
    """
    Replace special characters in a schema name with underscores.

    A character is "special" when it is a space or any member of
    ``string.punctuation`` other than the underscore itself. All other
    characters (letters, digits, "_", ...) are copied through unchanged.

    Args:
        key (str): original field name, possibly containing special characters.

    Returns:
        str, the name with every special character replaced by "_",
        e.g. "a b" becomes "a_b" and "image/encoded" becomes "image_encoded".
    """
    # Punctuation plus the space, minus "_" which is already a valid character.
    replaceable = (set(punctuation) | {' '}) - {'_'}
    return ''.join('_' if char in replaceable else char for char in key)
|
|
|
|
|
|
|
|
|
|
def verify_data(transformer, reader):
|
|
|
|
|
"""Verify the data by read from mindrecord"""
|
|
|
|
|
tf_iter = transformer.tfrecord_iterator()
|
|
|
|
@ -43,14 +65,14 @@ def verify_data(transformer, reader):
|
|
|
|
|
count = 0
|
|
|
|
|
for tf_item, mr_item in zip(tf_iter, mr_iter):
|
|
|
|
|
count = count + 1
|
|
|
|
|
assert len(tf_item) == 6
|
|
|
|
|
assert len(mr_item) == 6
|
|
|
|
|
assert len(tf_item) == len(mr_item)
|
|
|
|
|
for key, value in tf_item.items():
|
|
|
|
|
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key]))
|
|
|
|
|
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value,
|
|
|
|
|
mr_item[cast_name(key)]))
|
|
|
|
|
if isinstance(value, np.ndarray):
|
|
|
|
|
assert (value == mr_item[key]).all()
|
|
|
|
|
assert (value == mr_item[cast_name(key)]).all()
|
|
|
|
|
else:
|
|
|
|
|
assert value == mr_item[key]
|
|
|
|
|
assert value == mr_item[cast_name(key)]
|
|
|
|
|
assert count == 10
|
|
|
|
|
|
|
|
|
|
def generate_tfrecord():
|
|
|
|
@ -102,6 +124,39 @@ def generate_tfrecord():
|
|
|
|
|
writer.close()
|
|
|
|
|
logger.info("Write {} rows in tfrecord.".format(example_count))
|
|
|
|
|
|
|
|
|
|
def generate_tfrecord_with_special_field_name():
    """Write a 10-row TFRecord file whose feature names contain "/".

    Every row carries two features:
      * "image/class/label": an int64 scalar equal to the row index (0..9);
      * "image/encoded": the bytes of "aaaabbbbcccc" followed by the row index.

    The file is written to TFRECORD_DATA_DIR/TFRECORD_FILE_NAME and the number
    of rows written is logged.
    """
    def create_int_feature(values):
        """Wrap a single int or a list of ints in an Int64List feature."""
        if isinstance(values, list):
            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))  # values: [int, int, int]
        else:
            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))  # values: int
        return feature

    def create_bytes_feature(values):
        """Wrap bytes, or a str encoded as UTF-8, in a BytesList feature."""
        if isinstance(values, bytes):
            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))  # values: bytes
        else:
            # values: string
            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
        return feature

    example_count = 0
    # Use the writer as a context manager so it is closed (and its buffer
    # flushed) even when building or writing a row raises.
    with tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) as writer:
        for i in range(10):
            label = i
            image_bytes = "aaaabbbbcccc{}".format(i).encode("utf-8")

            features = collections.OrderedDict()
            features["image/class/label"] = create_int_feature(label)
            features["image/encoded"] = create_bytes_feature(image_bytes)

            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
            example_count += 1
    logger.info("Write {} rows in tfrecord.".format(example_count))
|
|
|
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord():
|
|
|
|
|
"""test transform tfrecord to mindrecord."""
|
|
|
|
|
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
|
|
|
@ -398,3 +453,110 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
|
|
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db")
|
|
|
|
|
|
|
|
|
|
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
|
|
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
    """Expect ValueError when bytes_fields names a feature that is not tf.string.

    "int64_list" is declared as tf.int64 in the feature dict but is passed as
    a bytes field, so constructing or running TFRecordToMR must raise
    ValueError. All files produced by the test are removed afterwards, even
    when the test fails.
    """
    # NOTE(review): if both operands are strings this is a lexicographic
    # comparison (e.g. "1.9" > "1.15") -- confirm against the definition of
    # SupportedTensorFlowVersion.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_files = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    try:
        generate_tfrecord()
        assert os.path.exists(tfrecord_path)

        feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                        "image_bytes": tf.io.FixedLenFeature([], tf.string),
                        "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                        "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                        "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                        "float_list": tf.io.FixedLenFeature([7], tf.float32),
                        }

        # Remove output left over from an earlier run before converting.
        for stale in mindrecord_files:
            if os.path.exists(stale):
                os.remove(stale)

        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                                feature_dict, ["int64_list"])
            tfrecord_transformer.transform()
    finally:
        # Clean up on failure too, so leftovers cannot affect other tests.
        for leftover in mindrecord_files + (tfrecord_path,):
            if os.path.exists(leftover):
                os.remove(leftover)
|
|
|
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
    """Expect ValueError when bytes_fields is not a list.

    An empty string is passed where the other tests pass a list of field
    names, so constructing or running TFRecordToMR must raise ValueError.
    All files produced by the test are removed afterwards, even when the
    test fails.
    """
    # NOTE(review): if both operands are strings this is a lexicographic
    # comparison (e.g. "1.9" > "1.15") -- confirm against the definition of
    # SupportedTensorFlowVersion.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_files = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    try:
        generate_tfrecord()
        assert os.path.exists(tfrecord_path)

        feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                        "image_bytes": tf.io.FixedLenFeature([], tf.string),
                        "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                        "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                        "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                        "float_list": tf.io.FixedLenFeature([7], tf.float32),
                        }

        # Remove output left over from an earlier run before converting.
        for stale in mindrecord_files:
            if os.path.exists(stale):
                os.remove(stale)

        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                                feature_dict, "")
            tfrecord_transformer.transform()
    finally:
        # Clean up on failure too, so leftovers cannot affect other tests.
        for leftover in mindrecord_files + (tfrecord_path,):
            if os.path.exists(leftover):
                os.remove(leftover)
|
|
|
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord_with_special_field_name():
    """Convert a TFRecord whose field names contain "/" and verify the result.

    The source file uses the feature names "image/class/label" and
    "image/encoded". After the transform, the MindRecord file and its ".db"
    companion must exist, and verify_data compares the rows read back from
    MindRecord with the TFRecord rows. All files produced by the test are
    removed afterwards, even when the test fails.
    """
    # NOTE(review): if both operands are strings this is a lexicographic
    # comparison (e.g. "1.9" > "1.15") -- confirm against the definition of
    # SupportedTensorFlowVersion.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_files = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    try:
        generate_tfrecord_with_special_field_name()
        assert os.path.exists(tfrecord_path)

        feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64),
                        "image/encoded": tf.io.FixedLenFeature([], tf.string),
                        }

        # Remove output left over from an earlier run before converting.
        for stale in mindrecord_files:
            if os.path.exists(stale):
                os.remove(stale)

        tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                            feature_dict, ["image/encoded"])
        tfrecord_transformer.transform()

        assert os.path.exists(MINDRECORD_FILE_NAME)
        assert os.path.exists(MINDRECORD_FILE_NAME + ".db")

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        # Clean up on failure too, so leftovers cannot affect other tests.
        for leftover in mindrecord_files + (tfrecord_path,):
            if os.path.exists(leftover):
                os.remove(leftover)
|
|
|
|
|