|
|
|
@ -65,21 +65,9 @@ void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &arg
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_number, const TypeId &scalar_type,
|
|
|
|
|
const size_t &s_type_number) {
|
|
|
|
|
if (scalar_type == kNumberTypeFloat16 || scalar_type == kNumberTypeFloat32 || scalar_type == kNumberTypeFloat64) {
|
|
|
|
|
if (tensor_type == kNumberTypeFloat16 || tensor_type == kNumberTypeFloat32 || tensor_type == kNumberTypeFloat64) {
|
|
|
|
|
return t_type_number >= s_type_number;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
|
|
|
|
|
const size_t type_number) {
|
|
|
|
|
void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) {
|
|
|
|
|
*max_type_id = type_id;
|
|
|
|
|
*max_type = type;
|
|
|
|
|
*max_type_number = type_number;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -118,7 +106,6 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
|
|
|
|
|
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
|
|
|
|
|
const std::set<size_t> &write_indices) {
|
|
|
|
|
TypeId max_type_id = kTypeUnknown;
|
|
|
|
|
TypeId max_type = kTypeUnknown;
|
|
|
|
|
size_t max_type_number = 0;
|
|
|
|
|
bool has_int8 = false;
|
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
@ -128,6 +115,9 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
|
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (arg_type != kObjectTypeTensorType) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto it = type_map.find(arg_type_id);
|
|
|
|
|
if (it == type_map.end()) {
|
|
|
|
|
continue;
|
|
|
|
@ -136,24 +126,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|
|
|
|
has_int8 = true;
|
|
|
|
|
}
|
|
|
|
|
if (max_type_id == kTypeUnknown) {
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (max_type == arg_type) {
|
|
|
|
|
if (it->second > max_type_number) {
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (arg_type == kObjectTypeTensorType) {
|
|
|
|
|
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
|
|
|
|
|
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (it->second > max_type_number) {
|
|
|
|
|
SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|