diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py
index 8b34ff280e..15682db22a 100644
--- a/mindspore/dataset/text/transforms.py
+++ b/mindspore/dataset/text/transforms.py
@@ -51,7 +51,8 @@ from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPi
 from .validators import check_lookup, check_jieba_add_dict, \
     check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
     check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
-    check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
+    check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow, \
+    check_sentence_piece_tokenizer
 from ..core.datatypes import mstype_to_detype
 from ..core.validator_helpers import replace_none
 from ..transforms.c_transforms import TensorOperation
@@ -325,7 +326,7 @@ class SentencePieceTokenizer(TextTensorOperation):
     Args:
         mode (Union[str, SentencePieceVocab]): If the input parameter is a file, then it is of type string.
             If the input parameter is a SentencePieceVocab object, then it is of type SentencePieceVocab.
-        out_type (Union[str, int]): The type of output.
+        out_type (SPieceTokenizerOutType): The type of the output, which can be int or string.
 
     Examples:
         >>> from mindspore.dataset.text import SentencePieceModel, SPieceTokenizerOutType
@@ -335,7 +336,7 @@ class SentencePieceTokenizer(TextTensorOperation):
         >>> tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
         >>> text_file_dataset = text_file_dataset.map(operations=tokenizer)
     """
-
+    @check_sentence_piece_tokenizer
     def __init__(self, mode, out_type):
         self.mode = mode
         self.out_type = out_type
diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py
index 7cd191c054..cf96b74643 100644
--- a/mindspore/dataset/text/validators.py
+++ b/mindspore/dataset/text/validators.py
@@ -515,3 +515,21 @@ def check_save_model(method):
         return method(self, *args, **kwargs)
 
     return new_method
+
+
+def check_sentence_piece_tokenizer(method):
+    """A wrapper that wraps a parameter checker to the original function."""
+
+    # Deferred import avoids a circular dependency between validators and utils.
+    from .utils import SPieceTokenizerOutType
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [mode, out_type], _ = parse_user_args(method, *args, **kwargs)
+
+        type_check(mode, (str, cde.SentencePieceVocab), "mode")
+        type_check(out_type, (SPieceTokenizerOutType,), "out_type")
+
+        return method(self, *args, **kwargs)
+
+    return new_method
diff --git a/mindspore/dataset/vision/validators.py b/mindspore/dataset/vision/validators.py
index 10459ceb12..126a7d545b 100644
--- a/mindspore/dataset/vision/validators.py
+++ b/mindspore/dataset/vision/validators.py
@@ -31,7 +31,8 @@ def check_crop_size(size):
     if isinstance(size, int):
         check_value(size, (1, FLOAT_MAX_INTEGER))
     elif isinstance(size, (tuple, list)) and len(size) == 2:
-        for value in size:
+        for index, value in enumerate(size):
+            type_check(value, (int,), "size[{}]".format(index))
             check_value(value, (1, FLOAT_MAX_INTEGER))
     else:
         raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
@@ -93,6 +94,8 @@ def check_normalize_c_param(mean, std):
 
 
 def check_normalize_py_param(mean, std):
+    type_check(mean, (list, tuple), "mean")
+    type_check(std, (list, tuple), "std")
     if len(mean) != len(std):
         raise ValueError("Length of mean and std must be equal.")
     for mean_value in mean:
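
For review, a minimal doctest-style sketch of the behavior the new validator adds, following the Examples block in the SentencePieceTokenizer docstring; the corpus file name is hypothetical, and the exact error text comes from type_check in validator_helpers rather than from this patch:

>>> import mindspore.dataset.text as text
>>> from mindspore.dataset.text import SentencePieceModel, SPieceTokenizerOutType
>>> vocab = text.SentencePieceVocab.from_file(["corpus.txt"], 5000, 0.9995,
...                                           SentencePieceModel.UNIGRAM, {})
>>> # Accepted: out_type is an SPieceTokenizerOutType member.
>>> tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
>>> # Rejected by check_sentence_piece_tokenizer: a plain string is not an
>>> # SPieceTokenizerOutType member, so type_check raises TypeError at
>>> # construction time instead of the failure surfacing later in map().
>>> try:
...     text.SentencePieceTokenizer(vocab, out_type="string")
... except TypeError as error:
...     print(error)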