diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py
index 30c0263ab4..5d0ad6f00c 100644
--- a/mindspore/dataset/engine/serializer_deserializer.py
+++ b/mindspore/dataset/engine/serializer_deserializer.py
@@ -168,6 +168,17 @@ def create_node(node):
     # Find a matching Dataset class and call the constructor with the corresponding args.
     # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
     # Dataset Source Ops (in alphabetical order)
+    pyobj = create_dataset_node(pyclass, node, dataset_op)
+    if not pyobj:
+        # Dataset Ops (in alphabetical order)
+        pyobj = create_dataset_operation_node(node, dataset_op)
+
+    return pyobj
+
+
+def create_dataset_node(pyclass, node, dataset_op):
+    """Parse the keys and values in the dataset node dictionary and instantiate the Python Dataset object"""
+    pyobj = None
     if dataset_op == 'CelebADataset':
         sampler = construct_sampler(node.get('sampler'))
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
@@ -189,7 +200,7 @@ def create_node(node):
 
     elif dataset_op == 'ClueDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], node.get('task'),
@@ -205,7 +216,7 @@ def create_node(node):
 
     elif dataset_op == 'CSVDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], node.get('field_delim'),
@@ -237,7 +248,7 @@ def create_node(node):
 
     elif dataset_op == 'TextFileDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], num_samples,
@@ -246,7 +257,7 @@ def create_node(node):
 
     elif dataset_op == 'TFRecordDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'),
@@ -260,8 +271,13 @@ def create_node(node):
                         num_samples, node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'),
                         sampler, node.get('num_shards'), node.get('shard_id'))
 
-    # Dataset Ops (in alphabetical order)
-    elif dataset_op == 'Batch':
+    return pyobj
+
+
+def create_dataset_operation_node(node, dataset_op):
+    """Parse the keys and values in the dataset operation node dictionary and instantiate the Python Dataset object"""
+    pyobj = None
+    if dataset_op == 'Batch':
         pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
 
     elif dataset_op == 'Map':
@@ -292,7 +308,7 @@ def create_node(node):
         pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue'))
 
     elif dataset_op == 'Zip':
-        # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
+        # Create ZipDataset instance, giving dummy input dataset that will be overridden in the caller.
         pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
 
     else:
diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py
index e18ab78703..d8fb88d4b2 100644
--- a/mindspore/dataset/text/utils.py
+++ b/mindspore/dataset/text/utils.py
@@ -24,7 +24,6 @@
 import mindspore._c_dataengine as cde
 from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
     check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
-
 __all__ = [
     "Vocab", "SentencePieceVocab", "to_str", "to_bytes"
 ]
@@ -66,7 +65,7 @@ class Vocab(cde.Vocab):
                 is specified and special_first is set to True, special_tokens will be prepended (default=True).
 
         Returns:
-            Vocab, Vocab object built from dataset.
+            Vocab, vocab built from the dataset.
         """
 
         return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
@@ -82,6 +81,9 @@ class Vocab(cde.Vocab):
                 special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
             special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
                 is specified and special_first is set to True, special_tokens will be prepended (default=True).
+
+        Returns:
+            Vocab, vocab built from the `list`.
         """
         if special_tokens is None:
             special_tokens = []
@@ -103,6 +105,9 @@ class Vocab(cde.Vocab):
             special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
                 is specified and special_first is set to True, special_tokens will be prepended (default=True).
+
+        Returns:
+            Vocab, vocab built from the file.
         """
 
         if vocab_size is None:
             vocab_size = -1
@@ -119,6 +124,9 @@ class Vocab(cde.Vocab):
         Args:
             word_dict (dict): dict contains word and id pairs, where word should be str and id be int. id is recommended
                 to start from 0 and be continuous. ValueError will be raised if id is negative.
+
+        Returns:
+            Vocab, vocab built from the `dict`.
         """
 
         return super().from_dict(word_dict)
@@ -147,7 +155,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
             params(dict): A dictionary with no incoming parameters.
 
         Returns:
-            SentencePiece, SentencePiece object from dataset.
+            SentencePieceVocab, vocab built from the dataset.
         """
         return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
                                                  model_type, params)
@@ -174,6 +182,9 @@ class SentencePieceVocab(cde.SentencePieceVocab):
 
                 input_sentence_size 0
                 max_sentencepiece_length 16
+
+        Returns:
+            SentencePieceVocab, vocab built from the file.
         """
         return super().from_file(file_path, vocab_size, character_coverage,
                                  DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
@@ -189,7 +200,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
             path(str): Path to store model.
             filename(str): The name of the file.
         """
-        return super().save_model(vocab, path, filename)
+        super().save_model(vocab, path, filename)
 
 
 def to_str(array, encoding='utf8'):
diff --git a/mindspore/mindrecord/filereader.py b/mindspore/mindrecord/filereader.py
index 5869f87de5..df88939538 100644
--- a/mindspore/mindrecord/filereader.py
+++ b/mindspore/mindrecord/filereader.py
@@ -38,6 +38,7 @@ class FileReader:
     Raises:
         ParamValueError: If file_name, num_consumer or columns is invalid.
     """
+
     def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
         if isinstance(file_name, list):
             for f in file_name:
@@ -66,7 +67,6 @@ class FileReader:
         self._header = ShardHeader(self._reader.get_header())
         self._reader.launch()
 
-
     def get_next(self):
         """
         Yield a batch of data according to columns at a time.
@@ -85,4 +85,4 @@ class FileReader:
 
     def close(self):
         """Stop reader worker and close File."""
-        return self._reader.close()
+        self._reader.close()
diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py
index 2c4b3d85e1..54571f8871 100644
--- a/mindspore/mindrecord/tools/tfrecord_to_mr.py
+++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py
@@ -69,7 +69,7 @@ class TFRecordToMR:
 
     Args:
         source (str): the TFRecord file to be transformed.
-        destination (str): the MindRecord file path to tranform into.
+        destination (str): the MindRecord file path to transform into.
        feature_dict (dict): a dictionary that states the feature type, e.g.
            feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
                            "yyyy": tf.io.FixedLenFeature([], tf.int64)}
@@ -90,31 +90,14 @@ class TFRecordToMR:
         try:
             self.tf = import_module("tensorflow")  # just used to convert tfrecord to mindrecord
         except ModuleNotFoundError:
-            self.tf = None
-        if not self.tf:
             raise Exception("Module tensorflow is not found, please use pip install it.")
 
         if self.tf.__version__ < SupportedTensorFlowVersion:
             raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion))
 
-        if not isinstance(source, str):
-            raise ValueError("Parameter source must be string.")
-        check_filename(source)
-
-        if not isinstance(destination, str):
-            raise ValueError("Parameter destination must be string.")
-        check_filename(destination)
-
+        self._check_input(source, destination, feature_dict)
         self.source = source
         self.destination = destination
-
-        if feature_dict is None or not isinstance(feature_dict, dict):
-            raise ValueError("Parameter feature_dict is None or not dict.")
-
-        for key, val in feature_dict.items():
-            if not isinstance(val, self.tf.io.FixedLenFeature):
-                raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
-
         self.feature_dict = feature_dict
 
         bytes_fields_list = []
@@ -162,6 +145,23 @@
             mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]}
         self.mindrecord_schema = mindrecord_schema
 
+    def _check_input(self, source, destination, feature_dict):
+        """Validate the inputs of the init method."""
+        if not isinstance(source, str):
+            raise ValueError("Parameter source must be string.")
+        check_filename(source)
+
+        if not isinstance(destination, str):
+            raise ValueError("Parameter destination must be string.")
+        check_filename(destination)
+
+        if feature_dict is None or not isinstance(feature_dict, dict):
+            raise ValueError("Parameter feature_dict is None or not dict.")
+
+        for _, val in feature_dict.items():
+            if not isinstance(val, self.tf.io.FixedLenFeature):
+                raise ValueError("Parameter feature_dict: {} only supports FixedLenFeature.".format(feature_dict))
+
     def _parse_record(self, example):
         """Returns features for a single example"""
         features = self.tf.io.parse_single_example(example, features=self.feature_dict)
@@ -206,6 +206,9 @@
         """
         Yield a dict with key to be fields in schema, and value to be data.
         This function is for old version tensorflow whose version number < 2.1.0
+
+        Yields:
+            dict, data dictionary whose keys are the same as columns.
""" dataset = self.tf.data.TFRecordDataset(self.source) dataset = dataset.map(self._parse_record) @@ -235,7 +238,12 @@ class TFRecordToMR: raise ValueError("TFRecord feature_dict parameter error.") def tfrecord_iterator(self): - """Yield a dictionary whose keys are fields in schema.""" + """ + Yield a dictionary whose keys are fields in schema. + + Yields: + dict, data dictionary whose keys are the same as columns. + """ dataset = self.tf.data.TFRecordDataset(self.source) dataset = dataset.map(self._parse_record) iterator = dataset.__iter__() @@ -265,7 +273,7 @@ class TFRecordToMR: Execute transformation from TFRecord to MindRecord. Returns: - MSRStatus, whether TFRecord is successfuly transformed to MindRecord. + MSRStatus, whether TFRecord is successfully transformed to MindRecord. """ writer = FileWriter(self.destination) logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"