From fbe66231780861834f0648038060376e26c9f28f Mon Sep 17 00:00:00 2001 From: yefeng Date: Wed, 31 Mar 2021 14:04:06 +0800 Subject: [PATCH] fix_if_subgraph --- .../lite/tools/converter/parser/tf/tf_model_parser.cc | 10 +++++++--- .../lite/tools/converter/parser/tf/tf_model_parser.h | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 5a05e93a0b..a0ac14f94e 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -699,13 +699,15 @@ STATUS TFModelParser::ConvertSubgraph() { // add while cond body function to while node input if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) { - if (sub_graph_name.find("cond") != std::string::npos) { + if (find(while_cond_branch_name_.begin(), while_cond_branch_name_.end(), sub_graph_name) != + while_cond_branch_name_.end()) { while_cond_map[cnode] = sub_func_graph; } else { while_body_map[cnode] = sub_func_graph; } } else { - if (sub_graph_name.find("true") != std::string::npos) { + if (find(if_then_branch_name_.begin(), if_then_branch_name_.end(), sub_graph_name) != + if_then_branch_name_.end()) { if_then_map[cnode] = sub_func_graph; } else { if_else_map[cnode] = sub_func_graph; @@ -914,13 +916,15 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { auto cond_name = attr_value.func().name(); function_while_map_[cond_name] = anf_node; + while_cond_branch_name_.push_back(cond_name); MS_LOG(DEBUG) << "parse cond name:" << cond_name; } - } else if (op_type == "StatelessIf") { + } else if (op_type == "StatelessIf" || op_type == "If") { MS_LOG(INFO) << "find if node:" << node_def.name(); tensorflow::AttrValue attr_value; if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) { auto then_name = attr_value.func().name(); + if_then_branch_name_.push_back(then_name); function_if_map_[then_name] = anf_node; MS_LOG(DEBUG) << "parse then name:" << then_name; } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 5a6b5fbb25..173158c4e0 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -90,6 +90,8 @@ class TFModelParser : public ModelParser { std::map function_while_map_; // tf function name->while_node_name std::map function_if_map_; // tf function name->if_node std::vector>> nodes_with_null_input_{}; + std::vector while_cond_branch_name_; + std::vector if_then_branch_name_; }; } // namespace lite } // namespace mindspore