|
|
|
@ -58,8 +58,6 @@ public:
|
|
|
|
|
/// netoutput
|
|
|
|
|
/// ...
|
|
|
|
|
/// data0 data1 data2
|
|
|
|
|
/// | \ /
|
|
|
|
|
/// conv add
|
|
|
|
|
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
|
|
|
|
|
PassManager pass_manager;
|
|
|
|
|
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass);
|
|
|
|
@ -81,6 +79,11 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
|
|
|
|
|
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2));
|
|
|
|
|
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));
|
|
|
|
|
|
|
|
|
|
auto case_node = parent_graph->FindNode("parent_case");
|
|
|
|
|
EXPECT_NE(case_node, nullptr);
|
|
|
|
|
size_t input_data_node_num = case_node->GetInDataNodes().size();
|
|
|
|
|
EXPECT_EQ(input_data_node_num, 3);
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
|
|
|
|
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
|
|
|
|
|
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
|
|
|
@ -98,6 +101,12 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
|
|
|
|
|
sub_graph->SetParentNode(parent_case);
|
|
|
|
|
sub_graph->SetParentGraph(parent_graph);
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
|
|
|
|
|
// after pass, data1 and data2 are fused to data0
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1");
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr);
|
|
|
|
|
auto data2_node = sub_graph->FindNode("data2");
|
|
|
|
|
EXPECT_EQ(data2_node, nullptr);
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// graph with subgraph
|
|
|
|
@ -152,5 +161,8 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
|
|
|
|
|
sub_graph->SetParentNode(parent_case);
|
|
|
|
|
sub_graph->SetParentGraph(parent_graph);
|
|
|
|
|
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
|
|
|
|
|
// after pass, data1 is fused to data0
|
|
|
|
|
auto data1_node = sub_graph->FindNode("data1");
|
|
|
|
|
EXPECT_EQ(data1_node, nullptr);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ge
|
|
|
|
|