|
|
@ -258,51 +258,75 @@ TEST_F(UtestGeHybrid, init_weight_success) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
|
|
|
|
TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
|
|
|
|
ComputeGraphPtr merged_graph = nullptr;
|
|
|
|
ComputeGraphPtr merged_graph = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
|
|
|
|
ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
|
|
|
|
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
|
|
|
|
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
|
|
|
|
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
|
|
|
|
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
|
|
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
|
|
|
|
ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
|
|
|
|
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT);
|
|
|
|
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT);
|
|
|
|
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/
|
|
|
|
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/
|
|
|
|
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
|
|
|
|
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
|
|
|
|
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
|
|
|
|
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
|
|
|
|
sub_sub_graph2->SetGraphUnknownFlag(true);
|
|
|
|
sub_sub_graph2->SetGraphUnknownFlag(true);
|
|
|
|
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD);
|
|
|
|
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD);
|
|
|
|
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node);
|
|
|
|
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node);
|
|
|
|
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node);
|
|
|
|
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node);
|
|
|
|
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/
|
|
|
|
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/
|
|
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
|
|
|
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
|
|
|
|
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
|
|
|
|
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
|
|
|
|
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
|
|
|
|
sub_graph->SetGraphUnknownFlag(true);
|
|
|
|
sub_graph->SetGraphUnknownFlag(true);
|
|
|
|
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
|
|
|
|
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
|
|
|
|
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
|
|
|
|
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
|
|
|
|
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
|
|
|
|
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
|
|
|
|
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");
|
|
|
|
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");
|
|
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
|
|
|
|
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
|
|
|
|
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
|
|
|
|
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
|
|
|
|
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
|
|
|
|
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
|
|
|
|
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
|
|
|
|
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
|
|
|
|
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");
|
|
|
|
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");
|
|
|
|
|
|
|
|
|
|
|
|
root_graph->AddSubGraph(sub_sub_graph1);
|
|
|
|
root_graph->AddSubGraph(sub_sub_graph1);
|
|
|
|
root_graph->AddSubGraph(sub_sub_graph2);
|
|
|
|
root_graph->AddSubGraph(sub_sub_graph2);
|
|
|
|
sub_sub_graph1->SetParentGraph(root_graph);
|
|
|
|
sub_sub_graph1->SetParentGraph(root_graph);
|
|
|
|
sub_sub_graph2->SetParentGraph(root_graph);
|
|
|
|
sub_sub_graph2->SetParentGraph(root_graph);
|
|
|
|
sub_sub_graph1->SetParentNode(sub_graph_while_node);
|
|
|
|
sub_sub_graph1->SetParentNode(sub_graph_while_node);
|
|
|
|
sub_sub_graph2->SetParentNode(sub_graph_while_node);
|
|
|
|
sub_sub_graph2->SetParentNode(sub_graph_while_node);
|
|
|
|
|
|
|
|
|
|
|
|
root_graph->AddSubGraph(sub_graph);
|
|
|
|
root_graph->AddSubGraph(sub_graph);
|
|
|
|
sub_graph->SetParentNode(partitioned_call_node);
|
|
|
|
sub_graph->SetParentNode(partitioned_call_node);
|
|
|
|
sub_graph->SetParentGraph(root_graph);
|
|
|
|
sub_graph->SetParentGraph(root_graph);
|
|
|
|
|
|
|
|
|
|
|
|
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
|
|
|
|
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
|
|
|
|
HybridModel hybrid_model(root_model);
|
|
|
|
HybridModel hybrid_model(root_model);
|
|
|
|
HybridModelBuilder hybrid_model_builder(hybrid_model);
|
|
|
|
HybridModelBuilder hybrid_model_builder(hybrid_model);
|
|
|
|
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
|
|
|
|
|
|
|
|
|
|
|
|
// subgraph num before unfold: 1
|
|
|
|
|
|
|
|
EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 3);
|
|
|
|
|
|
|
|
// num of nodes in root_graph before unfold: 1, name: partitioned_call
|
|
|
|
|
|
|
|
EXPECT_EQ(root_graph->GetDirectNodesSize(), 1);
|
|
|
|
|
|
|
|
EXPECT_EQ(root_graph->GetDirectNode().at(0)->GetName(), "partitioned_call");
|
|
|
|
|
|
|
|
// two sub_sub_graphs: while cond & while body, their parent graph is "subgraph" before unfold
|
|
|
|
|
|
|
|
EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "root_graph");
|
|
|
|
|
|
|
|
EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "root_graph");
|
|
|
|
|
|
|
|
// node "cond_data" & "body_data" has owner compute graph "subgraph" before unfold
|
|
|
|
|
|
|
|
EXPECT_EQ(sub_graph_while_node->GetOwnerComputeGraph()->GetName(), "sub_graph");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// unfold success
|
|
|
|
|
|
|
|
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// subgraph num after unfold: 0
|
|
|
|
|
|
|
|
EXPECT_EQ(merged_graph->GetAllSubgraphs().size(), 2);
|
|
|
|
|
|
|
|
// num of nodes in MergedGraph after unfold: 1, name: while
|
|
|
|
|
|
|
|
EXPECT_EQ(merged_graph->GetDirectNodesSize(), 1);
|
|
|
|
|
|
|
|
EXPECT_EQ(merged_graph->GetDirectNode().at(0)->GetName(), "while");
|
|
|
|
|
|
|
|
// two sub_sub_graphs: while cond & while body, their parent graph is "MergedGraph" after unfold
|
|
|
|
|
|
|
|
EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "MergedGraph" );
|
|
|
|
|
|
|
|
EXPECT_EQ(sub_sub_graph1->GetParentGraph()->GetName(), "MergedGraph");
|
|
|
|
|
|
|
|
// node "cond_data" & "body_data" has owner compute graph "MergedGraph" before unfold
|
|
|
|
|
|
|
|
EXPECT_EQ(sub_graph_while_node->GetOwnerComputeGraph()->GetName(), "MergedGraph");
|
|
|
|
}
|
|
|
|
}
|
|
|
|