diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc index 22a2ad7b9c..d1326ac943 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -96,7 +96,7 @@ TypeId GetMaxTypeId(const std::vector &input_types, std::vector TypeId max_type_id = kTypeUnknown; size_t max_type_number = 0; bool has_int8 = false; - bool has_scalar_int32 = false; + bool has_scalar_int64 = false; bool has_scalar_float32 = false; for (const auto &index : indices) { TypeId arg_type_id = kTypeUnknown; @@ -105,8 +105,8 @@ TypeId GetMaxTypeId(const std::vector &input_types, std::vector continue; } if (arg_type != kObjectTypeTensorType) { - if (arg_type_id == kNumberTypeInt32) { - has_scalar_int32 = true; + if (arg_type_id == kNumberTypeInt64) { + has_scalar_int64 = true; } else if (arg_type_id == kNumberTypeFloat32) { has_scalar_float32 = true; } @@ -135,8 +135,8 @@ TypeId GetMaxTypeId(const std::vector &input_types, std::vector // if so, it means that max is bool tensor, use scalar type instead. // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) if (max_type_id == kNumberTypeBool) { - if (has_scalar_int32) { - max_type_id = kNumberTypeInt32; + if (has_scalar_int64) { + max_type_id = kNumberTypeInt64; } if (has_scalar_float32) { max_type_id = kNumberTypeFloat32; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index fe0ed330c1..12b9d54b72 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -181,15 +181,15 @@ std::map GetDstType(const py::tuple &py_args, } size_t priority = 0; TypeId max_type = TypeId::kTypeUnknown; - bool has_float = false; - bool has_int = false; - bool has_int8 = false; + bool has_scalar_float32 = false; + bool has_scalar_int64 = false; + bool has_tensor_int8 = false; for (size_t index : indexes) { - if (!has_float && py::isinstance(py_args[index])) { - has_float = true; + if (!has_scalar_float32 && py::isinstance(py_args[index])) { + has_scalar_float32 = true; } - if (!has_int && !py::isinstance(py_args[index]) && py::isinstance(py_args[index])) { - has_int = true; + if (!has_scalar_int64 && !py::isinstance(py_args[index]) && py::isinstance(py_args[index])) { + has_scalar_int64 = true; } auto obj = py_args[index]; @@ -201,7 +201,7 @@ std::map GetDstType(const py::tuple &py_args, continue; } if (arg_type_id == kNumberTypeInt8) { - has_int8 = true; + has_tensor_int8 = true; } if (type_priority->second > priority) { max_type = type_priority->first; @@ -210,18 +210,18 @@ std::map GetDstType(const py::tuple &py_args, } } if (max_type == TypeId::kNumberTypeBool) { - if (has_int) { + if (has_scalar_int64) { max_type = TypeId::kNumberTypeInt64; } - if (has_float) { + if (has_scalar_float32) { max_type = TypeId::kNumberTypeFloat32; } } if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 && - max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_float) { + max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) { max_type = TypeId::kNumberTypeFloat32; } - if (max_type == TypeId::kNumberTypeUInt8 && has_int8) { + if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) { max_type = TypeId::kNumberTypeInt16; } (void)dst_type.emplace(std::make_pair(type, max_type)); diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index dc641d0321..bad2ab4429 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1780,8 +1780,7 @@ class TopK(PrimitiveWithInfer): >>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16) >>> k = 3 >>> values, indices = topk(input_x, k) - >>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16).all() - >>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32).all() + ([5.0, 4.0, 3.0], [4, 3, 2]) """ @prim_attr_register