diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index ab40932c71..9cd5c1b1a8 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -523,38 +523,39 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no
 
 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
-  TypePtr type_ptr = node->Type();
-  MS_EXCEPTION_IF_NULL(type_ptr);
-  if (type_ptr->isa<TensorType>() && output_idx == 0) {
-    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
-    MS_EXCEPTION_IF_NULL(tensor_ptr);
-    TypePtr elem = tensor_ptr->element();
-    MS_EXCEPTION_IF_NULL(elem);
-    return elem->type_id();
-  } else if (type_ptr->isa<Tuple>()) {
-    auto tuple_ptr = type_ptr->cast<TuplePtr>();
-    MS_EXCEPTION_IF_NULL(tuple_ptr);
-    if (output_idx >= tuple_ptr->size()) {
-      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
-    }
-    auto tuple_i = (*tuple_ptr)[output_idx];
-    MS_EXCEPTION_IF_NULL(tuple_i);
-    if (tuple_i->isa<TensorType>()) {
-      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
+  auto get_single_type = [](const TypePtr &type_ptr) -> TypeId {
+    MS_EXCEPTION_IF_NULL(type_ptr);
+    if (type_ptr->isa<TensorType>()) {
+      auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
       MS_EXCEPTION_IF_NULL(tensor_ptr);
       TypePtr elem = tensor_ptr->element();
       MS_EXCEPTION_IF_NULL(elem);
       return elem->type_id();
-    } else if (tuple_i->isa<Number>()) {
-      return tuple_i->type_id();
-    } else {
-      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
-      return tuple_i->type_id();
     }
-  } else if (type_ptr->isa<Number>()) {
+    if (type_ptr->isa<Number>()) {
+      return type_ptr->type_id();
+    }
     return type_ptr->type_id();
+  };
+  auto get_tuple_type = [get_single_type](const TypePtr &type_ptr, size_t output_idx) -> TypeId {
+    MS_EXCEPTION_IF_NULL(type_ptr);
+    if (!type_ptr->isa<Tuple>()) {
+      return get_single_type(type_ptr);
+    }
+    auto tuple_ptr = type_ptr->cast<TuplePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_ptr);
+    if (output_idx >= tuple_ptr->size()) {
+      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
+    }
+    return get_single_type((*tuple_ptr)[output_idx]);
+  };
+  TypePtr type_ptr = node->Type();
+  if (type_ptr->isa<RefType>()) {
+    auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
+    MS_EXCEPTION_IF_NULL(ref_type_ptr);
+    return get_tuple_type(ref_type_ptr->subtype(), output_idx);
   }
-  return type_ptr->type_id();
+  return get_tuple_type(type_ptr, output_idx);
 }
 
 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
index 94a89ff320..ef1f5f9f47 100644
--- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc
+++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc
@@ -414,12 +414,6 @@ void AscendControlParser::ChildGraphDataAssign(
                         << node->DebugString(5) << " gives " << args.size();
     }
     for (size_t i = 0; i < args.size(); ++i) {
-      if (args[i]->isa<Parameter>() && memo->find(child_graph) == memo->end()) {
-        MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString()
-                     << " should be reused, continue.";
-        link_list->emplace_back(args[i], params[i]);
-        continue;
-      }
       InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i]));
     }
   }
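
Note: the sketch below is a standalone illustration of the refactor in the anf_runtime_algorithm.cc hunk, not MindSpore code. All type names in it (Type, NumberType, TensorType, TupleType, RefType) are simplified stand-ins for the real IR type classes, and dynamic_pointer_cast stands in for isa/cast. It only mirrors the control-flow shape the patch introduces: a single-type helper, a tuple helper that delegates to it, and a RefType branch that unwraps the subtype and reuses both helpers.

// Standalone sketch with stand-in types; mirrors the patched GetOutputInferDataType flow.
#include <cstddef>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

struct Type {
  virtual ~Type() = default;
  virtual int type_id() const = 0;
};
using TypePtr = std::shared_ptr<const Type>;

struct NumberType : Type {  // leaf type: reports its own id
  explicit NumberType(int id) : id_(id) {}
  int type_id() const override { return id_; }
  int id_;
};
struct TensorType : Type {  // container: the interesting id is the element's
  explicit TensorType(TypePtr elem) : elem_(std::move(elem)) {}
  int type_id() const override { return 100; }
  TypePtr element() const { return elem_; }
  TypePtr elem_;
};
struct TupleType : Type {  // multi-output node type
  explicit TupleType(std::vector<TypePtr> items) : items_(std::move(items)) {}
  int type_id() const override { return 200; }
  std::size_t size() const { return items_.size(); }
  TypePtr operator[](std::size_t i) const { return items_[i]; }
  std::vector<TypePtr> items_;
};
struct RefType : Type {  // wrapper the patch learns to look through
  explicit RefType(TypePtr sub) : sub_(std::move(sub)) {}
  int type_id() const override { return 300; }
  TypePtr subtype() const { return sub_; }
  TypePtr sub_;
};

int GetOutputInferDataType(const TypePtr &node_type, std::size_t output_idx) {
  auto get_single_type = [](const TypePtr &type_ptr) -> int {
    if (auto tensor = std::dynamic_pointer_cast<const TensorType>(type_ptr)) {
      return tensor->element()->type_id();  // tensor: report the element type id
    }
    return type_ptr->type_id();  // number (or anything else): its own id
  };
  auto get_tuple_type = [&](const TypePtr &type_ptr, std::size_t idx) -> int {
    auto tuple = std::dynamic_pointer_cast<const TupleType>(type_ptr);
    if (tuple == nullptr) {
      return get_single_type(type_ptr);  // not a tuple: the index does not apply
    }
    if (idx >= tuple->size()) {
      throw std::out_of_range("output index must be less than output number");
    }
    return get_single_type((*tuple)[idx]);
  };
  if (auto ref = std::dynamic_pointer_cast<const RefType>(node_type)) {
    return get_tuple_type(ref->subtype(), output_idx);  // new branch: unwrap, then reuse helpers
  }
  return get_tuple_type(node_type, output_idx);
}

int main() {
  auto f32 = std::make_shared<NumberType>(43);
  auto tup = std::make_shared<TupleType>(std::vector<TypePtr>{std::make_shared<TensorType>(f32), f32});
  std::cout << GetOutputInferDataType(std::make_shared<RefType>(tup), 0) << '\n';  // prints 43
  return 0;
}

The design point the hunk makes is the same as here: before the change, a RefType-typed node fell through every branch and returned the wrapper's own type id; factoring the tuple/single-type resolution into helpers lets the RefType case unwrap once and reuse the existing logic unchanged.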