!14458 [MS_LITE] fix if op

From: @YeFeng_24
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
pull/14458/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d514f891ad

@ -699,13 +699,15 @@ STATUS TFModelParser::ConvertSubgraph() {
// add while cond body function to while node input
if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) {
if (sub_graph_name.find("cond") != std::string::npos) {
if (find(while_cond_branch_name_.begin(), while_cond_branch_name_.end(), sub_graph_name) !=
while_cond_branch_name_.end()) {
while_cond_map[cnode] = sub_func_graph;
} else {
while_body_map[cnode] = sub_func_graph;
}
} else {
if (sub_graph_name.find("true") != std::string::npos) {
if (find(if_then_branch_name_.begin(), if_then_branch_name_.end(), sub_graph_name) !=
if_then_branch_name_.end()) {
if_then_map[cnode] = sub_func_graph;
} else {
if_else_map[cnode] = sub_func_graph;
@ -914,13 +916,15 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
auto cond_name = attr_value.func().name();
function_while_map_[cond_name] = anf_node;
while_cond_branch_name_.push_back(cond_name);
MS_LOG(DEBUG) << "parse cond name:" << cond_name;
}
} else if (op_type == "StatelessIf") {
} else if (op_type == "StatelessIf" || op_type == "If") {
MS_LOG(INFO) << "find if node:" << node_def.name();
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) {
auto then_name = attr_value.func().name();
if_then_branch_name_.push_back(then_name);
function_if_map_[then_name] = anf_node;
MS_LOG(DEBUG) << "parse then name:" << then_name;
}

@ -90,6 +90,8 @@ class TFModelParser : public ModelParser {
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
std::map<std::string, AnfNodePtr> function_if_map_; // tf function name->if_node
std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
std::vector<std::string> while_cond_branch_name_;
std::vector<std::string> if_then_branch_name_;
};
} // namespace lite
} // namespace mindspore

Loading…
Cancel
Save