|
|
@ -106,6 +106,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
TypeId max_type_id = kTypeUnknown;
|
|
|
|
TypeId max_type_id = kTypeUnknown;
|
|
|
|
size_t max_type_number = 0;
|
|
|
|
size_t max_type_number = 0;
|
|
|
|
bool has_int8 = false;
|
|
|
|
bool has_int8 = false;
|
|
|
|
|
|
|
|
bool has_scalar_int32 = false;
|
|
|
|
|
|
|
|
bool has_scalar_float32 = false;
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
TypeId arg_type = kTypeUnknown;
|
|
|
|
TypeId arg_type = kTypeUnknown;
|
|
|
@ -114,6 +116,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (arg_type != kObjectTypeTensorType) {
|
|
|
|
if (arg_type != kObjectTypeTensorType) {
|
|
|
|
|
|
|
|
if (arg_type_id == kNumberTypeInt32) {
|
|
|
|
|
|
|
|
has_scalar_int32 = true;
|
|
|
|
|
|
|
|
} else if (arg_type_id == kNumberTypeFloat32) {
|
|
|
|
|
|
|
|
has_scalar_float32 = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto it = type_map.find(arg_type_id);
|
|
|
|
auto it = type_map.find(arg_type_id);
|
|
|
@ -135,6 +142,17 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
|
|
|
|
if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
|
|
|
|
max_type_id = kNumberTypeInt16;
|
|
|
|
max_type_id = kNumberTypeInt16;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// if bool is the max type, see if there is scalar input
|
|
|
|
|
|
|
|
// if so, it means that max is bool tensor, use scalar type instead.
|
|
|
|
|
|
|
|
// for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2])
|
|
|
|
|
|
|
|
if (max_type_id == kNumberTypeBool) {
|
|
|
|
|
|
|
|
if (has_scalar_int32) {
|
|
|
|
|
|
|
|
max_type_id = kNumberTypeInt32;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (has_scalar_float32) {
|
|
|
|
|
|
|
|
max_type_id = kNumberTypeFloat32;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
return max_type_id;
|
|
|
|
return max_type_id;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|