|
|
@ -523,38 +523,39 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no
|
|
|
|
|
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
TypePtr type_ptr = node->Type();
|
|
|
|
auto get_single_type = [](const TypePtr &type_ptr) -> TypeId {
|
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
|
|
|
if (type_ptr->isa<TensorType>() && output_idx == 0) {
|
|
|
|
if (type_ptr->isa<TensorType>()) {
|
|
|
|
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
|
|
|
|
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
|
|
|
|
|
|
|
TypePtr elem = tensor_ptr->element();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(elem);
|
|
|
|
|
|
|
|
return elem->type_id();
|
|
|
|
|
|
|
|
} else if (type_ptr->isa<Tuple>()) {
|
|
|
|
|
|
|
|
auto tuple_ptr = type_ptr->cast<TuplePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
|
|
|
|
|
|
|
if (output_idx >= tuple_ptr->size()) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto tuple_i = (*tuple_ptr)[output_idx];
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_i);
|
|
|
|
|
|
|
|
if (tuple_i->isa<TensorType>()) {
|
|
|
|
|
|
|
|
auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
|
|
|
TypePtr elem = tensor_ptr->element();
|
|
|
|
TypePtr elem = tensor_ptr->element();
|
|
|
|
MS_EXCEPTION_IF_NULL(elem);
|
|
|
|
MS_EXCEPTION_IF_NULL(elem);
|
|
|
|
return elem->type_id();
|
|
|
|
return elem->type_id();
|
|
|
|
} else if (tuple_i->isa<Number>()) {
|
|
|
|
|
|
|
|
return tuple_i->type_id();
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
|
|
|
|
|
|
|
|
return tuple_i->type_id();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else if (type_ptr->isa<Number>()) {
|
|
|
|
if (type_ptr->isa<Number>()) {
|
|
|
|
|
|
|
|
return type_ptr->type_id();
|
|
|
|
|
|
|
|
}
|
|
|
|
return type_ptr->type_id();
|
|
|
|
return type_ptr->type_id();
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
auto get_tuple_type = [get_single_type](const TypePtr &type_ptr, size_t output_idx) -> TypeId {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
|
|
|
|
|
|
|
if (!type_ptr->isa<Tuple>()) {
|
|
|
|
|
|
|
|
return get_single_type(type_ptr);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto tuple_ptr = type_ptr->cast<TuplePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
|
|
|
|
|
|
|
if (output_idx >= tuple_ptr->size()) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return get_single_type((*tuple_ptr)[output_idx]);
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
TypePtr type_ptr = node->Type();
|
|
|
|
|
|
|
|
if (type_ptr->isa<RefType>()) {
|
|
|
|
|
|
|
|
auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ref_type_ptr);
|
|
|
|
|
|
|
|
return get_tuple_type(ref_type_ptr->subtype(), output_idx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return type_ptr->type_id();
|
|
|
|
return get_tuple_type(type_ptr, output_idx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
|
|
|
|