diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc
index c70cfe5d46..7e34026d1e 100644
--- a/mindspore/ccsrc/operator/composite/do_signature.cc
+++ b/mindspore/ccsrc/operator/composite/do_signature.cc
@@ -106,6 +106,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
   TypeId max_type_id = kTypeUnknown;
   size_t max_type_number = 0;
   bool has_int8 = false;
+  bool has_scalar_int32 = false;
+  bool has_scalar_float32 = false;
   for (const auto &index : indices) {
     TypeId arg_type_id = kTypeUnknown;
     TypeId arg_type = kTypeUnknown;
@@ -114,6 +116,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
       continue;
     }
     if (arg_type != kObjectTypeTensorType) {
+      if (arg_type_id == kNumberTypeInt32) {
+        has_scalar_int32 = true;
+      } else if (arg_type_id == kNumberTypeFloat32) {
+        has_scalar_float32 = true;
+      }
       continue;
     }
     auto it = type_map.find(arg_type_id);
@@ -135,6 +142,17 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
   if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
     max_type_id = kNumberTypeInt16;
   }
+  // If bool is the max type, check whether there is a scalar input;
+  // if so, the max type came from a bool tensor, so use the scalar's type instead.
+  // For example: Tensor([True, True]) * 2 is expected to yield Tensor([2, 2]).
+  if (max_type_id == kNumberTypeBool) {
+    if (has_scalar_int32) {
+      max_type_id = kNumberTypeInt32;
+    }
+    if (has_scalar_float32) {
+      max_type_id = kNumberTypeFloat32;
+    }
+  }
   return max_type_id;
 }
 
diff --git a/tests/st/ops/ascend/test_autocast.py b/tests/st/ops/ascend/test_autocast.py
index 448dc9b4d6..35690ce2c4 100644
--- a/tests/st/ops/ascend/test_autocast.py
+++ b/tests/st/ops/ascend/test_autocast.py
@@ -246,3 +246,21 @@ def test_tensor_auto_cast():
         bnet(t_fp32)
     with pytest.raises(TypeError):
         bnet(t_fp64)
+def test_bool_tensor_and_float():
+    context.set_context(mode=context.GRAPH_MODE)
+    t_bool = Tensor(np.ones([2, 1, 2, 2]).astype(np.bool), mstype.bool_)
+    t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
+    t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
+    t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
+    net = TensorFPAutoCast()
+    out = net(t_bool)
+    assert out.dtype == mstype.float32
+    net = TensorIntAutoCast()
+    out = net(t_bool)
+    assert out.dtype == mstype.int32
+    out = net(t_fp16)
+    assert out.dtype == mstype.float16
+    out = net(t_fp32)
+    assert out.dtype == mstype.float32
+    out = net(t_int32)
+    assert out.dtype == mstype.int32