From 63185cb20f4d4b621b7556e75ef9237c7867a8a4 Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Tue, 14 Jul 2020 10:20:02 -0400 Subject: [PATCH] fix some validators errors address review cmts addr review cmts --- mindspore/dataset/engine/datasets.py | 3 +- mindspore/dataset/text/validators.py | 13 +++--- tests/ut/python/dataset/test_from_dataset.py | 6 +-- tests/ut/python/dataset/test_ngram_op.py | 47 +++++++++----------- tests/ut/python/dataset/test_vocab.py | 18 +++++--- 5 files changed, 44 insertions(+), 43 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index c1ef6a9922..7b9a166a07 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1563,7 +1563,7 @@ class BatchDataset(DatasetOp): Number, number of batches. """ child_size = self.children[0].get_dataset_size() - if child_size is not None: + if child_size is not None and isinstance(self.batch_size, int): if self.drop_remainder: return math.floor(child_size / self.batch_size) return math.ceil(child_size / self.batch_size) @@ -3915,7 +3915,6 @@ class RandomDataset(SourceDataset): return self.sampler.is_sharded() - class Schema: """ Class to represent a schema of dataset. 
diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index 14c0ffe7c1..b0327f5609 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -23,7 +23,8 @@ import mindspore._c_dataengine as cde from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ - INT32_MAX, check_value + INT32_MAX, check_value, check_positive + def check_unique_list_of_words(words, arg_name): """Check that words is a list and each element is a str without any duplication""" @@ -109,7 +110,7 @@ def check_from_dict(method): for word, word_id in word_dict.items(): type_check(word, (str,), "word") type_check(word_id, (int,), "word_id") - check_value(word_id, (-1, INT32_MAX), "word_id") + check_value(word_id, (0, INT32_MAX), "word_id") return method(self, *args, **kwargs) return new_method @@ -196,7 +197,7 @@ def check_wordpiece_tokenizer(method): @wraps(method) def new_method(self, *args, **kwargs): - [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ =\ + [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \ parse_user_args(method, *args, **kwargs) if vocab is None: raise ValueError("vocab is not provided.") @@ -238,7 +239,7 @@ def check_basic_tokenizer(method): @wraps(method) def new_method(self, *args, **kwargs): - [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ =\ + [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \ parse_user_args(method, *args, **kwargs) if not isinstance(lower_case, bool): raise TypeError("Wrong input type for lower_case, should be boolean.") @@ -317,7 +318,7 @@ def check_from_dataset(method): type_check(top_k, (int, type(None)), "top_k") if isinstance(top_k, int): - check_value(top_k, (0, INT32_MAX), "top_k") + check_positive(top_k, "top_k") type_check(special_first, (bool,), "special_first") if special_tokens is not None: 
@@ -343,7 +344,7 @@ def check_ngram(method): for i, gram in enumerate(n): type_check(gram, (int,), "gram[{0}]".format(i)) - check_value(gram, (0, INT32_MAX), "gram_{}".format(i)) + check_positive(gram, "gram_{}".format(i)) if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( left_pad[1], int)): diff --git a/tests/ut/python/dataset/test_from_dataset.py b/tests/ut/python/dataset/test_from_dataset.py index 94a5a5df02..983052ea08 100644 --- a/tests/ut/python/dataset/test_from_dataset.py +++ b/tests/ut/python/dataset/test_from_dataset.py @@ -128,7 +128,7 @@ def test_from_dataset_exceptions(): data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k) assert isinstance(vocab.text.Vocab) - except (TypeError, ValueError, RuntimeError) as e: + except (TypeError, ValueError) as e: assert s in str(e), str(e) test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") @@ -136,8 +136,8 @@ def test_from_dataset_exceptions(): "Argument top_k with value 1.2345 is not of type (, )") test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (,)") test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") - test_config("text", (2, 3), 0, "top_k needs to be positive number") - test_config([123], (2, 3), 0, "top_k needs to be positive number") + test_config("text", (2, 3), 0, "top_k must be greater than 0") + test_config([123], (2, 3), -1, "top_k must be greater than 0") if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_ngram_op.py b/tests/ut/python/dataset/test_ngram_op.py index 8887b67500..777fca8764 100644 --- a/tests/ut/python/dataset/test_ngram_op.py +++ b/tests/ut/python/dataset/test_ngram_op.py @@ -72,43 +72,36 @@ def test_simple_ngram(): def test_corner_cases(): """ testing various corner cases and exceptions""" - def 
test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "): + def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "): def gen(texts): yield (np.array(texts.split(" "), dtype='S'),) - dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep)) - for data in dataset.create_dict_iterator(): - assert [d.decode("utf8") for d in data["text"]] == output_line, output_line + try: + dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"]) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep)) + for data in dataset.create_dict_iterator(): + return [d.decode("utf8") for d in data["text"]] + except (ValueError, TypeError) as e: + return str(e) # test tensor length smaller than n - test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5]) + assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""] # test empty separator - test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="") + assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia'] # test separator with longer length - test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^") + assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia'] # test left pad != right pad - test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1)) + assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State'] # test invalid n - try: - test_config("Yours to Discover", "", [0, [1]]) - except Exception as e: - assert "Argument gram[1] with value [1] is not of type (,)" in str(e) - # test empty n - try: - test_config("Yours to Discover", "", []) - except Exception as e: - assert "n needs to be a non-empty list" in str(e) 
+ assert "gram[1] with value [1] is not of type (,)" in test_config("Yours to Discover", [1, [1]]) + assert "n needs to be a non-empty list" in test_config("Yours to Discover", []) # test invalid pad - try: - test_config("Yours to Discover", "", [1], ("str", -1)) - except Exception as e: - assert "padding width need to be positive numbers" in str(e) - # test invalid pad - try: - test_config("Yours to Discover", "", [1], ("str", "rts")) - except Exception as e: - assert "pad needs to be a tuple of (str, int)" in str(e) + assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1)) + assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts")) + # test 0 as invalid input + assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0) + assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0]) + assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0]) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py index 901a822d5e..0545181360 100644 --- a/tests/ut/python/dataset/test_vocab.py +++ b/tests/ut/python/dataset/test_vocab.py @@ -60,6 +60,15 @@ def test_from_dict_tutorial(): ind += 1 +def test_from_dict_exception(): + try: + vocab = text.Vocab.from_dict({"home": -1, "behind": 0}) + if not vocab: + raise ValueError("Vocab is None") + except ValueError as e: + assert "is not within the required interval" in str(e) + + def test_from_list(): def gen(texts): for word in texts.split(" "): @@ -74,13 +83,11 @@ def test_from_list(): for d in data.create_dict_iterator(): res.append(d["text"].item()) return res - except ValueError as e: - return str(e) - except RuntimeError as e: - return str(e) - except TypeError as e: + except (ValueError, RuntimeError, TypeError) as e: return str(e) + # test basic default config, special_token=None, unknown_token=None + assert
test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2] # test normal operations assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1] assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4] @@ -129,6 +136,7 @@ def test_from_file(): if __name__ == '__main__': + test_from_dict_exception() test_from_list_tutorial() test_from_file_tutorial() test_from_dict_tutorial()