From f203c70cfdaefd8f9b750ed713dcc746e80598d4 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Mon, 15 Mar 2021 14:59:46 +0800 Subject: [PATCH] modified: tests/ut/ge/graph/passes/base_pass_unittest.cc --- .../ut/ge/graph/passes/base_pass_unittest.cc | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index 56a7077a..b1934359 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass { names_to_add_repass_.erase(iter); } } + // simulate infershape pass + if(node->GetType() == WHILE){ + bool need_repass = false; + AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); + if(!OptionExists(kOptimizeAfterSubGraph)){ + return SUCCESS; + } + if(need_repass){ + AddImmediateRePassNode(node); + } + else{ + // clear attr on while + node->GetOpDesc()->DelAttr("need_infer_again_"); + } + } return SUCCESS; } void clear() { iter_nodes_.clear(); } @@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) { EXPECT_EQ(test_pass.GetRunTimes(), 1007); } */ + TEST_F(UTESTGraphPassesBasePass, while_loop) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(true); @@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); } + +/// data1 const +/// \ / +/// while +/// / \ +/// | | +/// cast1 cast2 +ComputeGraphPtr BuildWhileGraph1() { + // build sub graph + auto builder_sub = ut::GraphBuilder("sub"); + auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); + auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); + auto add = builder_sub.AddNode("add", ADD, 2, 1); + + builder_sub.AddDataEdge(data_1, 0, add, 0); + builder_sub.AddDataEdge(data_2, 0, add, 1); + auto sub_graph = builder_sub.GetGraph(); + sub_graph->SetName("while_sub"); + // build root graph + auto builder = ut::GraphBuilder("g1"); + auto data = builder.AddNode("data1", DATA, 0, 1); + auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); + auto c1 = builder.AddNode("cast1", CAST, 1, 1); + auto c2 = builder.AddNode("cast2", CAST, 1, 1); + // add while op + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape({1,1,1,1})); + tensor_desc->SetFormat(FORMAT_ND); + tensor_desc->SetDataType(DT_INT32); + + auto op_desc = std::make_shared("while", WHILE); + for (int i = 0; i < 2; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < 2; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + AttrUtils::SetBool(op_desc,"need_infer_again_", true); + op_desc->AddSubgraphName(sub_graph->GetName()); + op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); + auto root_graph = builder.GetGraph(); + auto while_op = root_graph->AddNode(op_desc); + + builder.AddDataEdge(data, 0, while_op, 0); + builder.AddDataEdge(const_op, 0, while_op, 1); + builder.AddDataEdge(while_op, 0, c1, 0); + builder.AddDataEdge(while_op, 1, c2, 0); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(while_op); + root_graph->AddSubgraph(sub_graph); + return root_graph; +} + +TEST_F(UTESTGraphPassesBasePass, while_infershape) { +NamesToPass names_to_pass; +auto test_pass = UtestTestPass(); +names_to_pass.push_back(std::make_pair("test", &test_pass)); + +auto graph = BuildWhileGraph1(); +auto ge_pass = GEPass(graph); +auto while_node = graph->FindNode("while"); +EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); +EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + } // namespace ge