From 4d9491b5c32ab509039f1f4e2f14fcb22db43a0c Mon Sep 17 00:00:00 2001 From: buxue Date: Sat, 6 Jun 2020 17:18:17 +0800 Subject: [PATCH] fix bug of auto cast when there is scalar --- .../ccsrc/operator/composite/do_signature.cc | 37 ++++--------------- 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index 7f33d4a3c7..283afe5d5b 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -65,21 +65,9 @@ void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &arg } } } -bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_number, const TypeId &scalar_type, - const size_t &s_type_number) { - if (scalar_type == kNumberTypeFloat16 || scalar_type == kNumberTypeFloat32 || scalar_type == kNumberTypeFloat64) { - if (tensor_type == kNumberTypeFloat16 || tensor_type == kNumberTypeFloat32 || tensor_type == kNumberTypeFloat64) { - return t_type_number >= s_type_number; - } - return false; - } - return true; -} -void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, - const size_t type_number) { +void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { *max_type_id = type_id; - *max_type = type; *max_type_number = type_number; } @@ -118,7 +106,6 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, const std::set &write_indices) { TypeId max_type_id = kTypeUnknown; - TypeId max_type = kTypeUnknown; size_t max_type_number = 0; bool has_int8 = false; for (const auto &index : indices) { @@ -128,6 +115,9 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { continue; } + if (arg_type != kObjectTypeTensorType) { + continue; + } auto it = type_map.find(arg_type_id); if (it == type_map.end()) { continue; @@ -136,24 +126,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve has_int8 = true; } if (max_type_id == kTypeUnknown) { - SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); + SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); continue; } - - if (max_type == arg_type) { - if (it->second > max_type_number) { - SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); - } - } else { - if (arg_type == kObjectTypeTensorType) { - if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { - SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); - } - } else { - if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { - SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); - } - } + if (it->second > max_type_number) { + SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); } }