diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 7a33c9ceed..796933c6f1 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -276,7 +276,7 @@ def check_value_elements(data_dtype, types): else: raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") - elif mstype.issubclass_(ele, data_dtype): + elif mstype.dtype_to_pytype(ele) == mstype.dtype_to_pytype(data_dtype): scalars_number += 1 else: raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index b6a261d292..7985b9c232 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -278,8 +278,8 @@ class TensorSetItemByMixedTensors_1(Cell): class TensorSetItemByMixedTensors_2(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_2, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) - self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16), name="x") self.value = value @@ -911,7 +911,7 @@ test_cases = [ Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], }), ('TensorSetItemByMixedTensorsWithTensor_2', { - 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))), + 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float16))), 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], @@ -923,9 +923,9 @@ test_cases = [ Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], }), ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', { - 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)), - Tensor(np.zeros((4, 5), np.float32)), - Tensor(np.ones((4, 5), np.float32)))), + 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float16)), + Tensor(np.zeros((4, 5), np.float16)), + Tensor(np.ones((4, 5), np.float16)))), 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],