diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc
index 5465c1aa5a..b00eafba04 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc
@@ -27,6 +27,34 @@
 namespace mindspore {
 namespace opt {
 namespace {
+int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
+                         std::vector<AnfNodePtr> *plant_inputs) {
+  if (!AnfAlgo::IsTupleOutput(tuple_input)) {
+    auto abs = tuple_input->abstract();
+    MS_LOG(WARNING) << "This function only splits tuple-type outputs, but got " << abs->ToString();
+    return -1;
+  }
+  MS_EXCEPTION_IF_NULL(plant_inputs);
+  auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
+  if (tuple_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
+    auto make_tuple = tuple_input->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
+    for (size_t j = 0; j < tuple_input_num; ++j) {
+      // used by graph kernel
+      auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
+      MS_EXCEPTION_IF_NULL(dyn_input_node);
+      plant_inputs->emplace_back(dyn_input_node);
+    }
+    return input_size;
+  }
+  for (size_t index = 0; index < input_size; ++index) {
+    auto dyn_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
+    plant_inputs->emplace_back(dyn_input_node);
+  }
+  return input_size;
+}
+
 void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
   MS_EXCEPTION_IF_NULL(cnode_ptr);
   MS_EXCEPTION_IF_NULL(graph);
@@ -41,25 +69,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPt
   for (size_t i = 0; i < input_num; ++i) {
     auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
     MS_EXCEPTION_IF_NULL(input_node);
-    if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
-      auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
-      dyn_input_sizes.push_back(input_size);
-      auto make_tuple = input_node->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(make_tuple);
-      size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
-      for (size_t j = 0; j < tuple_input_num; ++j) {
-        auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
-        MS_EXCEPTION_IF_NULL(dyn_input_node);
-        if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
-          auto kernel_graph = graph->cast<KernelGraphPtr>();
-          MS_EXCEPTION_IF_NULL(kernel_graph);
-          auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
-          if (!success) {
-            MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
-          }
-        }
-        plant_inputs.push_back(dyn_input_node);
-      }
+    if (AnfAlgo::IsTupleOutput(input_node)) {
+      dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
     } else {
       dyn_input_sizes.push_back(-1);
       plant_inputs.push_back(input_node);
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index f8efe4072f..2e1d274179 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -1156,6 +1156,9 @@ uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
 bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
   MS_EXCEPTION_IF_NULL(anf);
   TypePtr type = anf->Type();
+  if (type == nullptr) {
+    return false;
+  }
   MS_EXCEPTION_IF_NULL(type);
   return type->isa<Tuple>();
 }
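For readers skimming the patch, the sketch below illustrates the flattening behaviour that the new SplitTupleInputs() centralizes: a tuple-producing input is expanded into its individual elements, and dyn_input_sizes records the element count (or -1 for non-tuple inputs). This is a minimal, standalone C++ analog, not part of the patch; the Node struct, IsTupleOutput flag, and main driver are simplified stand-ins invented for illustration rather than MindSpore APIs, and the real pass creates TupleGetItem nodes when the producer is not a MakeTuple.

// Illustrative sketch only -- models the tuple-flattening idea with
// simplified stand-in types instead of MindSpore's AnfNodePtr/CNode machinery.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  // Non-empty when this node produces a tuple (e.g. a MakeTuple-like node).
  std::vector<std::shared_ptr<Node>> tuple_elements;
  bool IsTupleOutput() const { return !tuple_elements.empty(); }
};
using NodePtr = std::shared_ptr<Node>;

// Flatten one tuple-producing input into individual inputs; return how many
// elements were appended, or -1 when the input is not a tuple.
int64_t SplitTupleInputs(const NodePtr &tuple_input, std::vector<NodePtr> *plant_inputs) {
  if (!tuple_input->IsTupleOutput()) {
    std::cerr << "Only tuple-type outputs are split, got: " << tuple_input->name << "\n";
    return -1;
  }
  for (const auto &element : tuple_input->tuple_elements) {
    plant_inputs->emplace_back(element);  // stand-in for TupleGetItem creation
  }
  return static_cast<int64_t>(tuple_input->tuple_elements.size());
}

int main() {
  auto a = std::make_shared<Node>(Node{"a", {}});
  auto b = std::make_shared<Node>(Node{"b", {}});
  auto scalar = std::make_shared<Node>(Node{"x", {}});
  auto make_tuple = std::make_shared<Node>(Node{"make_tuple", {a, b}});

  std::vector<NodePtr> plant_inputs;
  std::vector<int64_t> dyn_input_sizes;
  for (const auto &input : {make_tuple, scalar}) {
    if (input->IsTupleOutput()) {
      dyn_input_sizes.emplace_back(SplitTupleInputs(input, &plant_inputs));
    } else {
      dyn_input_sizes.push_back(-1);
      plant_inputs.push_back(input);
    }
  }
  // Expected output: plant_inputs = a b x, dyn_input_sizes = 2 -1
  for (const auto &n : plant_inputs) std::cout << n->name << " ";
  std::cout << "\n";
  for (auto s : dyn_input_sizes) std::cout << s << " ";
  std::cout << "\n";
  return 0;
}

The point of the refactor, as the sketch shows, is that the caller no longer special-cases MakeTuple inline; it only asks whether an input is tuple-typed and delegates the expansion, which is why the patch also hardens IsTupleOutput() against a null Type().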