!11975 improve isinstance

From: @zhangbuxue
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/11975/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 840d819d08

@ -366,20 +366,32 @@ def check_type_same(x_type, base_type):
Tensor: mstype.tensor_type, Tensor: mstype.tensor_type,
Parameter: mstype.ref_type Parameter: mstype.ref_type
} }
try:
if isinstance(base_type, list): has_int = False
raise TypeError("The second arg of 'isinstance' must be a type or a tuple of types, but got a list") has_tensor = False
if isinstance(base_type, tuple):
target_type = tuple(pytype_to_mstype[i] for i in base_type) def to_target_type(origin_type):
else: try:
target_type = (pytype_to_mstype[base_type],) if isinstance(origin_type, type):
if (isinstance(x_type, mstype.Bool) and mstype.Int in target_type) or \ ret_type = pytype_to_mstype[origin_type]
(isinstance(x_type, mstype.ref_type) and mstype.tensor_type in target_type): if ret_type == mstype.Int:
return True nonlocal has_int
return isinstance(x_type, target_type) has_int = True
except KeyError: if ret_type == mstype.tensor_type:
raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " nonlocal has_tensor
f"Tensor, Parameter, or a tuple containing only these types, but got {base_type}") has_tensor = True
return (ret_type,)
if isinstance(origin_type, tuple):
return tuple(to_target_type(i) for i in origin_type)
raise TypeError(f"The second arg of 'isinstance' must be a type or a tuple of types, "
f"but got a {type(origin_type).__name__}")
except KeyError:
raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, "
f"Tensor, Parameter, or a tuple containing only these types, but got {origin_type}")
target_type = to_target_type(base_type)
if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor):
return True
return isinstance(x_type, target_type)
@constexpr @constexpr

@ -43,10 +43,10 @@ def test_isinstance():
is_int = isinstance(self.int_member, int) is_int = isinstance(self.int_member, int)
is_float = isinstance(self.float_member, float) is_float = isinstance(self.float_member, float)
is_bool = isinstance(self.bool_member, bool) is_bool = isinstance(self.bool_member, bool)
bool_is_int = isinstance(self.bool_member, int) bool_is_int = isinstance(self.bool_member, (((int,)), float))
is_string = isinstance(self.string_member, str) is_string = isinstance(self.string_member, str)
is_parameter = isinstance(self.weight, Parameter) is_parameter = isinstance(self.weight, Parameter)
parameter_is_tensor = isinstance(self.weight, Tensor) parameter_is_tensor = isinstance(self.weight, ((Tensor, float), int))
is_tensor_const = isinstance(self.tensor_member, Tensor) is_tensor_const = isinstance(self.tensor_member, Tensor)
is_tensor_var = isinstance(x, Tensor) is_tensor_var = isinstance(x, Tensor)
is_tuple_const = isinstance(self.tuple_member, tuple) is_tuple_const = isinstance(self.tuple_member, tuple)
@ -88,8 +88,7 @@ def test_isinstance_not_supported():
net = Net() net = Net()
with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
net() net()
assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \ assert "The second arg of 'isinstance' must be a type or a tuple of types, but got a NoneType" in str(err.value)
"or a tuple containing only these types, but got None" in str(err.value)
def test_isinstance_second_arg_is_list(): def test_isinstance_second_arg_is_list():

Loading…
Cancel
Save