From 6364d56ebec7f6a4840d0583708a9f28134d4321 Mon Sep 17 00:00:00 2001 From: xsmq Date: Mon, 12 Oct 2020 14:16:04 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!6995=20?= =?UTF-8?q?:=20Support=20the=20transition=20of=20datarecord=20from=201.x.x?= =?UTF-8?q?=20version=20Tensorflow=20to=20Mindspore=20'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- graphengine | 2 +- mindspore/mindrecord/tools/tfrecord_to_mr.py | 79 +++---------------- .../python/mindrecord/test_tfrecord_to_mr.py | 13 +-- 3 files changed, 16 insertions(+), 78 deletions(-) diff --git a/graphengine b/graphengine index 14db109491..7a75f024d5 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 14db109491bc81473905a5eb9e82f6234aca419b +Subproject commit 7a75f024d5a70c51b6428008587c4125bc015349 diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py index 5aad3b075d..ce5b8400f0 100644 --- a/mindspore/mindrecord/tools/tfrecord_to_mr.py +++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -30,7 +30,7 @@ except ModuleNotFoundError: __all__ = ['TFRecordToMR'] -SupportedTensorFlowVersion = '1.13.0-rc1' +SupportedTensorFlowVersion = '2.1.0' def _cast_type(value): """ @@ -210,84 +210,30 @@ class TFRecordToMR: else: ms_dict[cast_key] = float(val.numpy()) - def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val): - """ - put data in ms_dict when field type is string - However, we have to make change due to the different structure of old version - """ - if isinstance(val, (bytes, str)): - if isinstance(val, (np.ndarray, list)): - raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) - if self.feature_dict[key].dtype == tf.string: - if cast_key in self.bytes_fields_list: - ms_dict[cast_key] = val - else: - ms_dict[cast_key] = val.decode("utf-8") - else: - ms_dict[cast_key] = val - else: - if _cast_type(self.feature_dict[key].dtype).startswith("int"): - ms_dict[cast_key] = int(val) - else: - ms_dict[cast_key] = float(val) - - def tfrecord_iterator_oldversion(self): - """ - Yield a dict with key to be fields in schema, and value to be data. - This function is for old version tensorflow whose version number < 2.1.0 - """ - dataset = tf.data.TFRecordDataset(self.source) - dataset = dataset.map(self._parse_record) - iterator = dataset.make_one_shot_iterator() - with tf.Session() as sess: - while True: - try: - ms_dict = {} - sample = iterator.get_next() - sample = sess.run(sample) - for key, val in sample.items(): - cast_key = _cast_name(key) - if cast_key in self.scalar_set: - self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val) - else: - if not isinstance(val, np.ndarray) and not isinstance(val, list): - raise ValueError("The response key: {}, value: {} from " - "TFRecord should be a ndarray or " - "list.".format(key, val)) - # list set - ms_dict[cast_key] = \ - np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) - yield ms_dict - except tf.errors.OutOfRangeError: - break - except tf.errors.InvalidArgumentError: - raise ValueError("TFRecord feature_dict parameter error.") - def tfrecord_iterator(self): """Yield a dict with key to be fields in schema, and value to be data.""" dataset = tf.data.TFRecordDataset(self.source) dataset = dataset.map(self._parse_record) iterator = dataset.__iter__() - while True: - try: + index_id = 0 + try: + for features in iterator: ms_dict = {} - sample = iterator.get_next() - for key, val in sample.items(): + index_id = index_id + 1 + for key, val in features.items(): cast_key = _cast_name(key) if cast_key in self.scalar_set: self._get_data_when_scalar_field(ms_dict, cast_key, key, val) else: if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \ - "list.".format(key, val)) + "list.".format(key, val)) # list set ms_dict[cast_key] = \ np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) yield ms_dict - except tf.errors.OutOfRangeError: - break - except tf.errors.InvalidArgumentError: - raise ValueError("TFRecord feature_dict parameter error.") + except tf.errors.InvalidArgumentError: + raise ValueError("TFRecord feature_dict parameter error.") def run(self): """ @@ -301,11 +247,10 @@ class TFRecordToMR: .format(self.mindrecord_schema, self.feature_dict)) writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") - if tf.__version__ < '2.0.0': - tf_iter = self.tfrecord_iterator_oldversion() - else: - tf_iter = self.tfrecord_iterator() + + tf_iter = self.tfrecord_iterator() batch_size = 256 + transform_count = 0 while True: data_list = [] diff --git a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py index e9f03a9cca..cfd0d53a49 100644 --- a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py +++ b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py @@ -23,7 +23,7 @@ from mindspore import log as logger from mindspore.mindrecord import FileReader from mindspore.mindrecord import TFRecordToMR -SupportedTensorFlowVersion = '1.13.0-rc1' +SupportedTensorFlowVersion = '2.1.0' try: tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord @@ -58,15 +58,8 @@ def cast_name(key): return casted_key def verify_data(transformer, reader): - """ - Verify the data by read from mindrecord - If in 1.x.x version, use old version to receive that iteration - """ - - if tf.__version__ < '2.0.0': - tf_iter = transformer.tfrecord_iterator_oldversion() - else: - tf_iter = transformer.tfrecord_iterator() + """Verify the data by read from mindrecord""" + tf_iter = transformer.tfrecord_iterator() mr_iter = reader.get_next() count = 0