|
|
|
@ -85,13 +85,13 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
|
|
|
|
|
EXPECT_EQ(input_data_node_num, 3);
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
|
|
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
|
|
|
|
|
auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data");
|
|
|
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
|
|
|
|
|
auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data");
|
|
|
|
|
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data");
|
|
|
|
|
auto data2 = MakeNode(sub_graph, 1, 1, "data2", "Data");
|
|
|
|
|
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
|
|
|
|
@ -100,19 +100,28 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
|
|
|
|
|
|
|
|
|
|
sub_graph->SetParentNode(parent_case);
|
|
|
|
|
sub_graph->SetParentGraph(parent_graph);
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
|
|
|
|
|
// after pass, data1 and data2 are fused to data0
|
|
|
|
|
parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph);
|
|
|
|
|
size_t sub_graph_num = parent_graph->GetAllSubgraphs().size();
|
|
|
|
|
EXPECT_EQ(sub_graph_num, 1);
|
|
|
|
|
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1");
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr);
|
|
|
|
|
EXPECT_NE(data1_node, nullptr);
|
|
|
|
|
auto data2_node = sub_graph->FindNode("data2");
|
|
|
|
|
EXPECT_EQ(data2_node, nullptr);
|
|
|
|
|
EXPECT_NE(data2_node, nullptr);
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS);
|
|
|
|
|
|
|
|
|
|
// after pass, data1 and data2 are fused to data0
|
|
|
|
|
data1_node = sub_graph->FindNode("data1");
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr);
|
|
|
|
|
data2_node = sub_graph->FindNode("data2");
|
|
|
|
|
EXPECT_EQ(data2_node, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// graph with subgraph
|
|
|
|
|
/// const
|
|
|
|
|
/// / \
|
|
|
|
|
/// cast1 cast2
|
|
|
|
|
/// cast1 cast1
|
|
|
|
|
/// \ /
|
|
|
|
|
/// case
|
|
|
|
|
/// |
|
|
|
|
@ -127,7 +136,6 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
|
|
|
|
|
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
|
|
|
|
|
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
|
|
|
|
|
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast");
|
|
|
|
|
auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast");
|
|
|
|
|
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case");
|
|
|
|
|
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");
|
|
|
|
|
|
|
|
|
@ -136,23 +144,21 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
|
|
|
|
|
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
|
|
|
|
|
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
|
|
|
|
|
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
|
|
|
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0));
|
|
|
|
|
GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
|
|
|
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
|
|
|
|
|
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
|
|
|
|
|
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
|
|
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
|
|
|
|
|
auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data");
|
|
|
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
|
|
|
|
|
auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data");
|
|
|
|
|
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
|
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
|
|
|
|
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
|
|
|
|
@ -160,9 +166,17 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
|
|
|
|
|
|
|
|
|
|
sub_graph->SetParentNode(parent_case);
|
|
|
|
|
sub_graph->SetParentGraph(parent_graph);
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
|
|
|
|
|
// after pass, data1 is fused to data0
|
|
|
|
|
parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph);
|
|
|
|
|
|
|
|
|
|
size_t sub_graph_num = parent_graph->GetAllSubgraphs().size();
|
|
|
|
|
EXPECT_EQ(sub_graph_num, 1);
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1");
|
|
|
|
|
EXPECT_NE(data1_node, nullptr);
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS);
|
|
|
|
|
|
|
|
|
|
// after pass, data1 is fused to data0
|
|
|
|
|
data1_node = sub_graph->FindNode("data1");
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ge
|
|
|
|
|