|
|
|
@ -699,13 +699,15 @@ STATUS TFModelParser::ConvertSubgraph() {
|
|
|
|
|
|
|
|
|
|
// add while cond body function to while node input
|
|
|
|
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
|
|
|
|
if (sub_graph_name.find("cond") != std::string::npos) {
|
|
|
|
|
if (find(while_cond_branch_name_.begin(), while_cond_branch_name_.end(), sub_graph_name) !=
|
|
|
|
|
while_cond_branch_name_.end()) {
|
|
|
|
|
while_cond_map[cnode] = sub_func_graph;
|
|
|
|
|
} else {
|
|
|
|
|
while_body_map[cnode] = sub_func_graph;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (sub_graph_name.find("true") != std::string::npos) {
|
|
|
|
|
if (find(if_then_branch_name_.begin(), if_then_branch_name_.end(), sub_graph_name) !=
|
|
|
|
|
if_then_branch_name_.end()) {
|
|
|
|
|
if_then_map[cnode] = sub_func_graph;
|
|
|
|
|
} else {
|
|
|
|
|
if_else_map[cnode] = sub_func_graph;
|
|
|
|
@ -914,13 +916,15 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|
|
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
|
|
|
|
|
auto cond_name = attr_value.func().name();
|
|
|
|
|
function_while_map_[cond_name] = anf_node;
|
|
|
|
|
while_cond_branch_name_.push_back(cond_name);
|
|
|
|
|
MS_LOG(DEBUG) << "parse cond name:" << cond_name;
|
|
|
|
|
}
|
|
|
|
|
} else if (op_type == "StatelessIf") {
|
|
|
|
|
} else if (op_type == "StatelessIf" || op_type == "If") {
|
|
|
|
|
MS_LOG(INFO) << "find if node:" << node_def.name();
|
|
|
|
|
tensorflow::AttrValue attr_value;
|
|
|
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) {
|
|
|
|
|
auto then_name = attr_value.func().name();
|
|
|
|
|
if_then_branch_name_.push_back(then_name);
|
|
|
|
|
function_if_map_[then_name] = anf_node;
|
|
|
|
|
MS_LOG(DEBUG) << "parse then name:" << then_name;
|
|
|
|
|
}
|
|
|
|
|