|
|
|
@ -68,8 +68,7 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
|
|
|
|
|
*max_type_number = type_number;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
|
|
|
|
|
TypeId *arg_type = nullptr) {
|
|
|
|
|
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) {
|
|
|
|
|
if (arg_type_origin->isa<TensorType>()) {
|
|
|
|
|
auto tensor = arg_type_origin->cast<TensorTypePtr>();
|
|
|
|
|
auto tensor_type = tensor->element();
|
|
|
|
@ -102,8 +101,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
|
|
|
|
|
for (const auto &index : indices) {
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
TypeId arg_type = kTypeUnknown;
|
|
|
|
|
auto is_write = (write_indices.find(index) != write_indices.end());
|
|
|
|
|
if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
|
|
|
|
|
if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (arg_type != kObjectTypeTensorType) {
|
|
|
|
@ -144,6 +142,10 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>
|
|
|
|
|
max_type_id = kNumberTypeFloat32;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
|
|
|
|
|
has_scalar_float32) {
|
|
|
|
|
max_type_id = kNumberTypeFloat32;
|
|
|
|
|
}
|
|
|
|
|
return max_type_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -218,7 +220,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|
|
|
|
|
|
|
|
|
TypeId arg_type_id = kTypeUnknown;
|
|
|
|
|
auto arg_value = input_types[i];
|
|
|
|
|
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
|
|
|
|
(void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id);
|
|
|
|
|
auto it_map = type_name_map.find(arg_type_id);
|
|
|
|
|
if (it_map == type_name_map.end()) {
|
|
|
|
|
continue;
|
|
|
|
|