diff --git a/mindspore/lite/test/models_tflite.cfg b/mindspore/lite/test/models_tflite.cfg index 9a3861ef47..648563278e 100644 --- a/mindspore/lite/test/models_tflite.cfg +++ b/mindspore/lite/test/models_tflite.cfg @@ -48,4 +48,6 @@ scan_hms_angle1.tflite hiai_latin_ocr.tflite hiai_latin_ocr_1.tflite ml_ocr_jk.tflite +nasnet_mobile.tflite +nasnet_large.tflite inception_resnet_v2.tflite diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index df0f980822..119cec8bba 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -15,6 +15,8 @@ */ #include "tools/optimizer/common/gllo_utils.h" #include +#include +#include #include "src/ir/primitive_t_value.h" #include "frontend/operator/ops.h" #include "backend/optimizer/common/helper.h" @@ -367,5 +369,29 @@ size_t GetOutputTensorNum(const AnfNodePtr &node) { return 1; } } + +bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) { + auto output_node_list = GetRealNodeUsedList(graph, node); + if (output_node_list->size() != 1) { + MS_LOG(DEBUG) << "fusion node has multi output nodes"; + return true; + } + return false; +} + +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node) { + auto output_node_list = std::make_shared>>(); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto iter = manager->node_users().find(node); + if (iter == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + auto output_info_list = iter->second; + std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list)); + return output_node_list; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 224bac3ceb..9827877be1 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -59,6 +59,8 @@ bool IsConvNode(const BaseRef &n); bool CheckIsAllInputsParam(const AnfNodePtr &node); size_t GetOutputTensorNum(const AnfNodePtr &node); + +bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc index dcd2fc90d6..14d1b3972f 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -51,6 +51,9 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c AnfNodePtr pre_node = act_node->input(1); CheckIfAnfNodeIsNull(pre_node); if (pre_node != nullptr && pre_node->isa()) { + if (IsMultiOutputTensors(func_graph, pre_node)) { + return node; + } auto conv_node = pre_node->cast(); auto node_type = GetCNodeType(conv_node); auto primitiveT_value = GetValueNode>(conv_node->input(0)); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index c7ebce2232..4b47a9f469 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -138,6 +138,9 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons AnfNodePtr conv_node_anf = add_node->input(1); CheckIfAnfNodeIsNull(conv_node_anf); + if (IsMultiOutputTensors(func_graph, conv_node_anf)) { + return add_node; + } auto conv_node = conv_node_anf->cast(); CheckIfCNodeIsNull(conv_node); GenConvNewBias(func_graph, conv_node, add_node); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index 3132618377..a5c09287a3 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -66,7 +66,9 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co auto pre_node = transform_node->input(1); auto conv_node = pre_node->cast(); - + if (IsMultiOutputTensors(func_graph, conv_node)) { + return transform_node; + } int kernel_nums = Get_Kenrnel_nums(conv_node); if (kernel_nums <= 0) { MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString();