From 5ede9ffccae453348cac4af21b898de7f734639d Mon Sep 17 00:00:00 2001 From: buxue Date: Tue, 2 Feb 2021 14:51:09 +0800 Subject: [PATCH] improve isinstance --- mindspore/_extends/parse/standard_method.py | 40 ++++++++++++------- .../python/pipeline/parse/test_isinstance.py | 7 ++-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index ebd77f07cf..37917fe682 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -366,20 +366,32 @@ def check_type_same(x_type, base_type): Tensor: mstype.tensor_type, Parameter: mstype.ref_type } - try: - if isinstance(base_type, list): - raise TypeError("The second arg of 'isinstance' must be a type or a tuple of types, but got a list") - if isinstance(base_type, tuple): - target_type = tuple(pytype_to_mstype[i] for i in base_type) - else: - target_type = (pytype_to_mstype[base_type],) - if (isinstance(x_type, mstype.Bool) and mstype.Int in target_type) or \ - (isinstance(x_type, mstype.ref_type) and mstype.tensor_type in target_type): - return True - return isinstance(x_type, target_type) - except KeyError: - raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " - f"Tensor, Parameter, or a tuple containing only these types, but got {base_type}") + + has_int = False + has_tensor = False + + def to_target_type(origin_type): + try: + if isinstance(origin_type, type): + ret_type = pytype_to_mstype[origin_type] + if ret_type == mstype.Int: + nonlocal has_int + has_int = True + if ret_type == mstype.tensor_type: + nonlocal has_tensor + has_tensor = True + return (ret_type,) + if isinstance(origin_type, tuple): + return tuple(to_target_type(i) for i in origin_type) + raise TypeError(f"The second arg of 'isinstance' must be a type or a tuple of types, " + f"but got a {type(origin_type).__name__}") + except KeyError: + raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " + f"Tensor, Parameter, or a tuple containing only these types, but got {origin_type}") + target_type = to_target_type(base_type) + if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor): + return True + return isinstance(x_type, target_type) @constexpr diff --git a/tests/ut/python/pipeline/parse/test_isinstance.py b/tests/ut/python/pipeline/parse/test_isinstance.py index 8abe40afca..99b40e07e4 100644 --- a/tests/ut/python/pipeline/parse/test_isinstance.py +++ b/tests/ut/python/pipeline/parse/test_isinstance.py @@ -43,10 +43,10 @@ def test_isinstance(): is_int = isinstance(self.int_member, int) is_float = isinstance(self.float_member, float) is_bool = isinstance(self.bool_member, bool) - bool_is_int = isinstance(self.bool_member, int) + bool_is_int = isinstance(self.bool_member, (((int,)), float)) is_string = isinstance(self.string_member, str) is_parameter = isinstance(self.weight, Parameter) - parameter_is_tensor = isinstance(self.weight, Tensor) + parameter_is_tensor = isinstance(self.weight, ((Tensor, float), int)) is_tensor_const = isinstance(self.tensor_member, Tensor) is_tensor_var = isinstance(x, Tensor) is_tuple_const = isinstance(self.tuple_member, tuple) @@ -88,8 +88,7 @@ def test_isinstance_not_supported(): net = Net() with pytest.raises(TypeError) as err: net() - assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \ - "or a tuple containing only these types, but got None" in str(err.value) + assert "The second arg of 'isinstance' must be a type or a tuple of types, but got a NoneType" in str(err.value) def test_isinstance_second_arg_is_list():