diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 60c0e883..f5a802a2 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -258,51 +258,75 @@ TEST_F(UtestGeHybrid, init_weight_success) { } TEST_F(UtestGeHybrid, unfold_subgraphs_success) { - ComputeGraphPtr merged_graph = nullptr; - - ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); - OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); - NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); - - ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); - /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); - NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ - OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); - NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); - sub_sub_graph2->SetGraphUnknownFlag(true); - /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); - NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ - - ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); - NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(true); - sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); - sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); - sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); - sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); - - ComputeGraphPtr root_graph = std::make_shared("root_graph"); - auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); - auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); - partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); - partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - - root_graph->AddSubGraph(sub_sub_graph1); - root_graph->AddSubGraph(sub_sub_graph2); - sub_sub_graph1->SetParentGraph(root_graph); - sub_sub_graph2->SetParentGraph(root_graph); - sub_sub_graph1->SetParentNode(sub_graph_while_node); - sub_sub_graph2->SetParentNode(sub_graph_while_node); - - root_graph->AddSubGraph(sub_graph); - sub_graph->SetParentNode(partitioned_call_node); - sub_graph->SetParentGraph(root_graph); - - GeRootModelPtr root_model = MakeShared(root_graph); - HybridModel hybrid_model(root_model); - HybridModelBuilder hybrid_model_builder(hybrid_model); - EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); +ComputeGraphPtr merged_graph = nullptr; + +ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); +OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); +NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + +ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); +/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); +NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ +OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); +NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); +sub_sub_graph2->SetGraphUnknownFlag(true); +/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); +NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); +sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); +sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ + +ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); +OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); +NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); +sub_graph->SetGraphUnknownFlag(true); +sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); +sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); +sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); +sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); + +ComputeGraphPtr root_graph = std::make_shared("root_graph"); +auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); +auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); +partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); +partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); + +root_graph->AddSubGraph(sub_sub_graph1); +root_graph->AddSubGraph(sub_sub_graph2); +sub_sub_graph1->SetParentGraph(root_graph); +sub_sub_graph2->SetParentGraph(root_graph); +sub_sub_graph1->SetParentNode(sub_graph_while_node); +sub_sub_graph2->SetParentNode(sub_graph_while_node); + +root_graph->AddSubGraph(sub_graph); +sub_graph->SetParentNode(partitioned_call_node); +sub_graph->SetParentGraph(root_graph); + +GeRootModelPtr root_model = MakeShared(root_graph); +HybridModel hybrid_model(root_model); +HybridModelBuilder hybrid_model_builder(hybrid_model); + +// subgraph num before unfold: 1 +EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 3); +// num of nodes in root_graph before unfold: 1, name: partitioned_call +EXPECT_EQ(root_graph->GetDirectNodesSize(), 1); +EXPECT_EQ(root_graph->GetDirectNode().at(0)->GetName(), "partitioned_call"); +// two sub_sub_graphs: while cond & while body, their parent graph is "subgraph" before unfold +EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "root_graph"); +EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "root_graph"); +// node "cond_data" & "body_data" has owner compute graph "subgraph" before unfold +EXPECT_EQ(sub_graph_while_node->GetOwnerComputeGraph()->GetName(), "sub_graph"); + +// unfold success +EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); + +// subgraph num after unfold: 0 +EXPECT_EQ(merged_graph->GetAllSubgraphs().size(), 2); +// num of nodes in MergedGraph after unfold: 1, name: while +EXPECT_EQ(merged_graph->GetDirectNodesSize(), 1); +EXPECT_EQ(merged_graph->GetDirectNode().at(0)->GetName(), "while"); +// two sub_sub_graphs: while cond & while body, their parent graph is "MergedGraph" after unfold +EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "MergedGraph" ); +EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "MergedGraph"); +// node "cond_data" & "body_data" has owner compute graph "MergedGraph" before unfold +EXPECT_EQ(sub_graph_while_node->GetOwnerComputeGraph()->GetName(), "MergedGraph"); }