diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 000d411d18..03c940f1b9 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -18,7 +18,7 @@ from dataclasses import dataclass -from mindspore import Tensor +from mindspore import Tensor, Parameter from mindspore import dtype as mstype from ..._checkparam import Validator as validator @@ -361,10 +361,13 @@ def check_type_same(x_type, base_type): str: mstype.String, list: mstype.List, tuple: mstype.Tuple, - Tensor: mstype.tensor_type + Tensor: mstype.tensor_type, + Parameter: mstype.ref_type } try: - if isinstance(base_type, (tuple, list)): + if isinstance(base_type, list): + raise TypeError("The second arg of 'isinstance' must be a type or a tuple of types, but got a list") + if isinstance(base_type, tuple): target_type = tuple(pytype_to_mstype[i] for i in base_type) else: target_type = pytype_to_mstype[base_type] diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 338ebb0227..e3b8e7ae76 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -111,6 +111,7 @@ none_type = typing.TypeNone env_type_type = typing.EnvType tensor_type = typing.TensorType anything_type = typing.TypeAnything +ref_type = typing.RefType number_type = (int8, int16, diff --git a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py index 067e0cd013..e72cdf2ad3 100644 --- a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py @@ -26,7 +26,7 @@ using ".register" decorator. @zeros_like_leaf.register("Number") -def _zeros_like_scala(x): +def _zeros_like_scalar(x): """Returns 0 which has the same dtype as x where x is a scalar.""" return 0 diff --git a/tests/ut/python/pipeline/parse/test_isinstance.py b/tests/ut/python/pipeline/parse/test_isinstance.py index beaca638b3..4507293bb0 100644 --- a/tests/ut/python/pipeline/parse/test_isinstance.py +++ b/tests/ut/python/pipeline/parse/test_isinstance.py @@ -17,7 +17,7 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore import Tensor +from mindspore import Tensor, Parameter from mindspore import context context.set_context(mode=context.GRAPH_MODE, save_graphs=True) @@ -34,12 +34,14 @@ def test_isinstance(): self.tensor_member = Tensor(np.arange(4)) self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) self.list_member = list(self.tuple_member) + self.weight = Parameter(1.0) def construct(self, x, y): is_int = isinstance(self.int_member, int) is_float = isinstance(self.float_member, float) is_bool = isinstance(self.bool_member, bool) is_string = isinstance(self.string_member, str) + is_parameter = isinstance(self.weight, Parameter) is_tensor_const = isinstance(self.tensor_member, Tensor) is_tensor_var = isinstance(x, Tensor) is_tuple_const = isinstance(self.tuple_member, tuple) @@ -52,7 +54,7 @@ def test_isinstance(): bool_is_string = isinstance(self.bool_member, str) tensor_is_tuple = isinstance(x, tuple) tuple_is_list = isinstance(self.tuple_member, list) - return is_int, is_float, is_bool, is_string, is_tensor_const, is_tensor_var, \ + return is_int, is_float, is_bool, is_string, is_parameter, is_tensor_const, is_tensor_var, \ is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list @@ -60,7 +62,7 @@ def test_isinstance(): net = Net() x = Tensor(np.arange(4)) y = Tensor(np.arange(5)) - assert net(x, y) == (True,) * 12 + (False,) * 4 + assert net(x, y) == (True,) * 13 + (False,) * 4 def test_isinstance_not_supported(): @@ -76,3 +78,18 @@ def test_isinstance_not_supported(): with pytest.raises(TypeError) as err: net() assert "The type 'None' is not supported for 'isinstance'" in str(err.value) + + +def test_isinstance_second_arg_is_list(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = (11, 22, 33, 44) + + def construct(self): + return isinstance(self.value, [tuple, int, float]) + + net = Net() + with pytest.raises(TypeError) as err: + net() + assert "The second arg of 'isinstance' must be a type or a tuple of types, but got a list" in str(err.value)