!11737 fix missing return descriptions in text.utils

From: @tiancixiao
Reviewed-by: @liucunwei, @heleiwang
Signed-off-by: @liucunwei
pull/11737/MERGE
mindspore-ci-bot committed 4 years ago, via Gitee
commit a78c8bad46

@ -168,6 +168,17 @@ def create_node(node):
# Find a matching Dataset class and call the constructor with the corresponding args.
# When a new Dataset class is introduced, another if clause and parsing code needs to be added.
# Dataset Source Ops (in alphabetical order)
pyobj = create_dataset_node(pyclass, node, dataset_op)
if not pyobj:
# Dataset Ops (in alphabetical order)
pyobj = create_dataset_operation_node(node, dataset_op)
return pyobj
def create_dataset_node(pyclass, node, dataset_op):
"""Parse the key, value in the dataset node dictionary and instantiate the Python Dataset object"""
pyobj = None
if dataset_op == 'CelebADataset':
sampler = construct_sampler(node.get('sampler'))
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
@ -189,7 +200,7 @@ def create_node(node):
elif dataset_op == 'ClueDataset':
shuffle = to_shuffle_mode(node.get('shuffle'))
if shuffle is not None and isinstance(shuffle, str):
if isinstance(shuffle, str):
shuffle = de.Shuffle(shuffle)
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
pyobj = pyclass(node['dataset_files'], node.get('task'),
@ -205,7 +216,7 @@ def create_node(node):
elif dataset_op == 'CSVDataset':
shuffle = to_shuffle_mode(node.get('shuffle'))
if shuffle is not None and isinstance(shuffle, str):
if isinstance(shuffle, str):
shuffle = de.Shuffle(shuffle)
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
pyobj = pyclass(node['dataset_files'], node.get('field_delim'),
@ -237,7 +248,7 @@ def create_node(node):
elif dataset_op == 'TextFileDataset':
shuffle = to_shuffle_mode(node.get('shuffle'))
if shuffle is not None and isinstance(shuffle, str):
if isinstance(shuffle, str):
shuffle = de.Shuffle(shuffle)
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
pyobj = pyclass(node['dataset_files'], num_samples,
@ -246,7 +257,7 @@ def create_node(node):
elif dataset_op == 'TFRecordDataset':
shuffle = to_shuffle_mode(node.get('shuffle'))
if shuffle is not None and isinstance(shuffle, str):
if isinstance(shuffle, str):
shuffle = de.Shuffle(shuffle)
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'),
@ -260,8 +271,13 @@ def create_node(node):
num_samples, node.get('num_parallel_workers'), node.get('shuffle'),
node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
# Dataset Ops (in alphabetical order)
elif dataset_op == 'Batch':
return pyobj
def create_dataset_operation_node(node, dataset_op):
"""Parse the key, value in the dataset operation node dictionary and instantiate the Python Dataset object"""
pyobj = None
if dataset_op == 'Batch':
pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
elif dataset_op == 'Map':
@ -292,7 +308,7 @@ def create_node(node):
pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue'))
elif dataset_op == 'Zip':
# Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
# Create ZipDataset instance, giving dummy input dataset that will be overrode in the caller.
pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
else:

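The refactor above splits create_node into create_dataset_node (dataset source ops) and create_dataset_operation_node (dataset ops); both are exercised when a serialized pipeline is rebuilt. A minimal sketch of that round trip, assuming the mindspore.dataset serialize/deserialize API of this release; the corpus path and pipeline shape are illustrative only:

    import mindspore.dataset as ds

    # Build a tiny pipeline, dump it to JSON, then rebuild it from that JSON.
    # Deserialization walks the JSON tree and calls create_node() for every op,
    # which dispatches to create_dataset_node() / create_dataset_operation_node().
    data = ds.TextFileDataset("/path/to/corpus.txt", num_samples=4, shuffle=False)  # hypothetical file
    data = data.batch(2, drop_remainder=True)

    ds.serialize(data, json_filepath="pipeline.json")
    rebuilt = ds.deserialize(json_filepath="pipeline.json")

    for row in rebuilt.create_dict_iterator(output_numpy=True):
        print(row)
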
@ -24,7 +24,6 @@ import mindspore._c_dataengine as cde
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
__all__ = [
"Vocab", "SentencePieceVocab", "to_str", "to_bytes"
]
@ -66,7 +65,7 @@ class Vocab(cde.Vocab):
is specified and special_first is set to True, special_tokens will be prepended (default=True).
Returns:
Vocab, Vocab object built from dataset.
Vocab, vocab built from the dataset.
"""
return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
@ -82,6 +81,9 @@ class Vocab(cde.Vocab):
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
is specified and special_first is set to True, special_tokens will be prepended (default=True).
Returns:
Vocab, vocab built from the `list`.
"""
if special_tokens is None:
special_tokens = []
@ -103,6 +105,9 @@ class Vocab(cde.Vocab):
special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
If special_tokens is specified and special_first is set to True,
special_tokens will be prepended (default=True).
Returns:
Vocab, vocab built from the file.
"""
if vocab_size is None:
vocab_size = -1
@ -119,6 +124,9 @@ class Vocab(cde.Vocab):
Args:
word_dict (dict): dict contains word and id pairs, where word should be str and id be int. id is recommended
to start from 0 and be continuous. ValueError will be raised if id is negative.
Returns:
Vocab, vocab built from the `dict`.
"""
return super().from_dict(word_dict)
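
The Returns blocks added above document that from_list, from_file and from_dict all hand back a Vocab. A minimal sketch of how such a Vocab is typically built and consumed, assuming mindspore.dataset.text.Lookup with its usual signature; the tokens and ids are made up:

    import mindspore.dataset.text as text

    # from_dict: ids should start at 0 and be continuous; a negative id raises ValueError.
    vocab = text.Vocab.from_dict({"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3})

    # from_list: special tokens are prepended because special_first=True.
    vocab = text.Vocab.from_list(["hello", "world"],
                                 special_tokens=["<pad>", "<unk>"],
                                 special_first=True)

    # The returned Vocab is normally handed to a Lookup op that maps tokens to ids.
    lookup = text.Lookup(vocab, unknown_token="<unk>")
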
@ -147,7 +155,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
params(dict): A dictionary with no incoming parameters.
Returns:
SentencePiece, SentencePiece object from dataset.
SentencePieceVocab, vocab built from the dataset.
"""
return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
@ -174,6 +182,9 @@ class SentencePieceVocab(cde.SentencePieceVocab):
input_sentence_size 0
max_sentencepiece_length 16
Returns:
SentencePieceVocab, vocab built from the file.
"""
return super().from_file(file_path, vocab_size, character_coverage,
DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
@ -189,7 +200,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
path(str): Path to store model.
filename(str): The name of the file.
"""
return super().save_model(vocab, path, filename)
super().save_model(vocab, path, filename)
def to_str(array, encoding='utf8'):

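save_model above now returns nothing instead of forwarding the C++ return value. A minimal sketch of the from_file/save_model flow it sits in, assuming the SentencePieceModel enum is importable from mindspore.dataset.text; the corpus path, vocab size and output names are illustrative:

    import mindspore.dataset.text as text

    # Train a SentencePiece vocab from a raw-text corpus and keep the vocab object.
    vocab = text.SentencePieceVocab.from_file(["/path/to/corpus.txt"],   # hypothetical corpus
                                              vocab_size=5000,
                                              character_coverage=0.9995,
                                              model_type=text.SentencePieceModel.UNIGRAM,
                                              params={})

    # save_model now returns None; it only writes <path>/<filename> to disk.
    text.SentencePieceVocab.save_model(vocab, "/tmp", "m.model")
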
@ -38,6 +38,7 @@ class FileReader:
Raises:
ParamValueError: If file_name, num_consumer or columns is invalid.
"""
def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
if isinstance(file_name, list):
for f in file_name:
@ -66,7 +67,6 @@ class FileReader:
self._header = ShardHeader(self._reader.get_header())
self._reader.launch()
def get_next(self):
"""
Yield a batch of data according to columns at a time.
@ -85,4 +85,4 @@ class FileReader:
def close(self):
"""Stop reader worker and close File."""
return self._reader.close()
self._reader.close()

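FileReader.close() likewise stops forwarding the C++ status. A minimal sketch of the read-then-close usage it supports, assuming an existing MindRecord file; the file name and column names are illustrative:

    from mindspore.mindrecord import FileReader

    # Open an existing MindRecord file and read just the listed columns.
    reader = FileReader(file_name="/path/to/data.mindrecord",
                        num_consumer=4,
                        columns=["data", "label"])
    count = 0
    for item in reader.get_next():   # one dict per sample, keyed by column name
        count += 1
    reader.close()                   # now returns None; it just stops the reader workers
    print("samples read:", count)
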
@ -69,7 +69,7 @@ class TFRecordToMR:
Args:
source (str): the TFRecord file to be transformed.
destination (str): the MindRecord file path to tranform into.
destination (str): the MindRecord file path to transform into.
feature_dict (dict): a dictionary that states the feature type, e.g.
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
@ -90,31 +90,14 @@ class TFRecordToMR:
try:
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
except ModuleNotFoundError:
self.tf = None
if not self.tf:
raise Exception("Module tensorflow is not found, please use pip install it.")
if self.tf.__version__ < SupportedTensorFlowVersion:
raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion))
if not isinstance(source, str):
raise ValueError("Parameter source must be string.")
check_filename(source)
if not isinstance(destination, str):
raise ValueError("Parameter destination must be string.")
check_filename(destination)
self._check_input(source, destination, feature_dict)
self.source = source
self.destination = destination
if feature_dict is None or not isinstance(feature_dict, dict):
raise ValueError("Parameter feature_dict is None or not dict.")
for key, val in feature_dict.items():
if not isinstance(val, self.tf.io.FixedLenFeature):
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
self.feature_dict = feature_dict
bytes_fields_list = []
@ -162,6 +145,23 @@ class TFRecordToMR:
mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]}
self.mindrecord_schema = mindrecord_schema
def _check_input(self, source, destination, feature_dict):
"""Validation check for inputs of init method"""
if not isinstance(source, str):
raise ValueError("Parameter source must be string.")
check_filename(source)
if not isinstance(destination, str):
raise ValueError("Parameter destination must be string.")
check_filename(destination)
if feature_dict is None or not isinstance(feature_dict, dict):
raise ValueError("Parameter feature_dict is None or not dict.")
for _, val in feature_dict.items():
if not isinstance(val, self.tf.io.FixedLenFeature):
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
def _parse_record(self, example):
"""Returns features for a single example"""
features = self.tf.io.parse_single_example(example, features=self.feature_dict)
@ -206,6 +206,9 @@ class TFRecordToMR:
"""
Yield a dict with key to be fields in schema, and value to be data.
This function is for old version tensorflow whose version number < 2.1.0
Yields:
dict, data dictionary whose keys are the same as columns.
"""
dataset = self.tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
@ -235,7 +238,12 @@ class TFRecordToMR:
raise ValueError("TFRecord feature_dict parameter error.")
def tfrecord_iterator(self):
"""Yield a dictionary whose keys are fields in schema."""
"""
Yield a dictionary whose keys are fields in schema.
Yields:
dict, data dictionary whose keys are the same as columns.
"""
dataset = self.tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.__iter__()
@ -265,7 +273,7 @@ class TFRecordToMR:
Execute transformation from TFRecord to MindRecord.
Returns:
MSRStatus, whether TFRecord is successfuly transformed to MindRecord.
MSRStatus, whether TFRecord is successfully transformed to MindRecord.
"""
writer = FileWriter(self.destination)
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"

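Taken together, the TFRecordToMR changes factor input validation into _check_input and document what the iterators and transform() return. A minimal sketch of the conversion the class performs, assuming TensorFlow >= 2.1 is installed; the file paths and feature_dict entries are illustrative:

    import tensorflow as tf
    from mindspore.mindrecord import TFRecordToMR

    # Only tf.io.FixedLenFeature values pass the _check_input validation above.
    feature_dict = {
        "image_bytes": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }

    converter = TFRecordToMR(source="/path/to/data.tfrecord",         # hypothetical input
                             destination="/path/to/data.mindrecord",  # hypothetical output
                             feature_dict=feature_dict,
                             bytes_fields=["image_bytes"])            # string fields holding raw bytes

    status = converter.transform()   # MSRStatus indicating whether the conversion succeeded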