@@ -30,7 +30,7 @@ except ModuleNotFoundError:
 
 __all__ = ['TFRecordToMR']
 
-SupportedTensorFlowVersion = '1.13.0-rc1'
+SupportedTensorFlowVersion = '2.1.0'
 
 def _cast_type(value):
     """
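
The hunk above only raises the module-level SupportedTensorFlowVersion constant from 1.13.0-rc1 to 2.1.0; the hunks below remove the code paths that existed for the older releases. As a minimal sketch of how such a constant is typically consumed (the helper name is hypothetical and the actual check is outside this excerpt), assuming the same plain string comparison as the tf.__version__ < '2.0.0' test removed further down:

import tensorflow as tf

SupportedTensorFlowVersion = '2.1.0'

def _check_tf_version():
    # Hypothetical gate: refuse to run on TensorFlow builds older than the supported one.
    if tf.__version__ < SupportedTensorFlowVersion:
        raise RuntimeError("TensorFlow >= {} is required, found {}."
                           .format(SupportedTensorFlowVersion, tf.__version__))
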
@@ -210,84 +210,30 @@ class TFRecordToMR:
             else:
                 ms_dict[cast_key] = float(val.numpy())
 
-    def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val):
-        """
-        Put data in ms_dict when the field is a scalar (string or numeric).
-        Kept separate from the new-version handler because the old TensorFlow path
-        returns plain Python/NumPy values rather than eager tensors.
-        """
-        if isinstance(val, (bytes, str)):
-            if isinstance(val, (np.ndarray, list)):
-                raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
-            if self.feature_dict[key].dtype == tf.string:
-                if cast_key in self.bytes_fields_list:
-                    ms_dict[cast_key] = val
-                else:
-                    ms_dict[cast_key] = val.decode("utf-8")
-            else:
-                ms_dict[cast_key] = val
-        else:
-            if _cast_type(self.feature_dict[key].dtype).startswith("int"):
-                ms_dict[cast_key] = int(val)
-            else:
-                ms_dict[cast_key] = float(val)
-
-    def tfrecord_iterator_oldversion(self):
-        """
-        Yield a dict whose keys are the schema fields and whose values are the data.
-        This path is for old TensorFlow versions (< 2.1.0).
-        """
-        dataset = tf.data.TFRecordDataset(self.source)
-        dataset = dataset.map(self._parse_record)
-        iterator = dataset.make_one_shot_iterator()
-        with tf.Session() as sess:
-            while True:
-                try:
-                    ms_dict = {}
-                    sample = iterator.get_next()
-                    sample = sess.run(sample)
-                    for key, val in sample.items():
-                        cast_key = _cast_name(key)
-                        if cast_key in self.scalar_set:
-                            self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val)
-                        else:
-                            if not isinstance(val, np.ndarray) and not isinstance(val, list):
-                                raise ValueError("The response key: {}, value: {} from "
-                                                 "TFRecord should be a ndarray or "
-                                                 "list.".format(key, val))
-                            # list set
-                            ms_dict[cast_key] = \
-                                np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
-                    yield ms_dict
-                except tf.errors.OutOfRangeError:
-                    break
-                except tf.errors.InvalidArgumentError:
-                    raise ValueError("TFRecord feature_dict parameter error.")
-
     def tfrecord_iterator(self):
         """Yield a dict whose keys are the schema fields and whose values are the data."""
         dataset = tf.data.TFRecordDataset(self.source)
         dataset = dataset.map(self._parse_record)
         iterator = dataset.__iter__()
-        while True:
-            try:
+        index_id = 0
+        try:
+            for features in iterator:
                 ms_dict = {}
-                sample = iterator.get_next()
-                for key, val in sample.items():
+                index_id = index_id + 1
+                for key, val in features.items():
                     cast_key = _cast_name(key)
                     if cast_key in self.scalar_set:
                         self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
                     else:
                         if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
                             raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or "
                                              "list.".format(key, val))
                         # list set
                         ms_dict[cast_key] = \
                             np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
                 yield ms_dict
-            except tf.errors.OutOfRangeError:
-                break
-            except tf.errors.InvalidArgumentError:
-                raise ValueError("TFRecord feature_dict parameter error.")
+        except tf.errors.InvalidArgumentError:
+            raise ValueError("TFRecord feature_dict parameter error.")
 
     def run(self):
         """
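
For context on the two code paths touched by the hunk above: the removed tfrecord_iterator_oldversion drains the dataset through a one-shot iterator inside a tf.Session, while the surviving tfrecord_iterator iterates the dataset eagerly. Below is a standalone sketch of both consumption styles; the file name and feature_dict are illustrative and do not come from this module:

import tensorflow as tf

feature_dict = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "data": tf.io.FixedLenFeature([], tf.string),
}

def parse(record):
    return tf.io.parse_single_example(record, feature_dict)

dataset = tf.data.TFRecordDataset("sample.tfrecord").map(parse)

if tf.__version__ < '2.0.0':
    # Graph mode: build the get_next op once, then run it until the dataset is exhausted.
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
        while True:
            try:
                sample = sess.run(next_element)   # dict of numpy scalars, bytes or arrays
            except tf.errors.OutOfRangeError:
                break
else:
    # Eager mode: the dataset is directly iterable and yields dicts of EagerTensors.
    for features in dataset:
        label = int(features["label"].numpy())
        raw_bytes = features["data"].numpy()
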
@@ -301,11 +247,10 @@ class TFRecordToMR:
                              .format(self.mindrecord_schema, self.feature_dict))
 
         writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
-        if tf.__version__ < '2.0.0':
-            tf_iter = self.tfrecord_iterator_oldversion()
-        else:
-            tf_iter = self.tfrecord_iterator()
+        tf_iter = self.tfrecord_iterator()
         batch_size = 256
 
         transform_count = 0
         while True:
             data_list = []
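
The rest of run() is not shown in this excerpt; after data_list = [] it presumably fills the list from tf_iter and flushes it through the writer in chunks of batch_size. A minimal, self-contained sketch of that accumulate-and-flush pattern using MindSpore's FileWriter API (the function name and parameters here are illustrative, not the actual body of run()):

from mindspore.mindrecord import FileWriter

def write_in_batches(sample_iter, mindrecord_schema, destination, batch_size=256):
    # Accumulate dicts produced by an iterator and write them to MindRecord in batches.
    writer = FileWriter(destination)
    writer.add_schema(mindrecord_schema, "TFRecord to MindRecord")
    transform_count = 0
    data_list = []
    for sample in sample_iter:
        data_list.append(sample)
        transform_count += 1
        if len(data_list) == batch_size:
            writer.write_raw_data(data_list)
            data_list = []
    if data_list:                          # flush the final partial batch
        writer.write_raw_data(data_list)
    writer.commit()
    return transform_count
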
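
Finally, an end-to-end usage sketch of the converter itself. The constructor signature TFRecordToMR(source, destination, feature_dict, bytes_fields) and the file names below are assumptions for illustration; run() is the entry point shown in this diff:

import tensorflow as tf
from mindspore.mindrecord import TFRecordToMR

# Illustrative TFRecord schema: an encoded image kept as raw bytes plus an int64 label.
feature_dict = {
    "image_raw": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

converter = TFRecordToMR("input.tfrecord",        # source TFRecord file
                         "output.mindrecord",     # destination MindRecord file
                         feature_dict,
                         bytes_fields=["image_raw"])
converter.run()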