modified: tests/ut/ge/graph/passes/base_pass_unittest.cc

pull/1251/head
zhaoxinxin 4 years ago
parent c067e32c68
commit f203c70cfd

@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass {
names_to_add_repass_.erase(iter);
}
}
// simulate infershape pass
if(node->GetType() == WHILE){
bool need_repass = false;
AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass);
if(!OptionExists(kOptimizeAfterSubGraph)){
return SUCCESS;
}
if(need_repass){
AddImmediateRePassNode(node);
}
else{
// clear attr on while
node->GetOpDesc()->DelAttr("need_infer_again_");
}
}
return SUCCESS;
}
void clear() { iter_nodes_.clear(); }
@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) {
EXPECT_EQ(test_pass.GetRunTimes(), 1007);
}
*/
TEST_F(UTESTGraphPassesBasePass, while_loop) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass(true);
@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) {
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
}
/// data1 const
/// \ /
/// while
/// / \
/// | |
/// cast1 cast2
ComputeGraphPtr BuildWhileGraph1() {
// build sub graph
auto builder_sub = ut::GraphBuilder("sub");
auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1);
auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1);
auto add = builder_sub.AddNode("add", ADD, 2, 1);
builder_sub.AddDataEdge(data_1, 0, add, 0);
builder_sub.AddDataEdge(data_2, 0, add, 1);
auto sub_graph = builder_sub.GetGraph();
sub_graph->SetName("while_sub");
// build root graph
auto builder = ut::GraphBuilder("g1");
auto data = builder.AddNode("data1", DATA, 0, 1);
auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1);
auto c1 = builder.AddNode("cast1", CAST, 1, 1);
auto c2 = builder.AddNode("cast2", CAST, 1, 1);
// add while op
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape({1,1,1,1}));
tensor_desc->SetFormat(FORMAT_ND);
tensor_desc->SetDataType(DT_INT32);
auto op_desc = std::make_shared<OpDesc>("while", WHILE);
for (int i = 0; i < 2; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < 2; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}
AttrUtils::SetBool(op_desc,"need_infer_again_", true);
op_desc->AddSubgraphName(sub_graph->GetName());
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName());
auto root_graph = builder.GetGraph();
auto while_op = root_graph->AddNode(op_desc);
builder.AddDataEdge(data, 0, while_op, 0);
builder.AddDataEdge(const_op, 0, while_op, 1);
builder.AddDataEdge(while_op, 0, c1, 0);
builder.AddDataEdge(while_op, 1, c2, 0);
sub_graph->SetParentGraph(root_graph);
sub_graph->SetParentNode(while_op);
root_graph->AddSubgraph(sub_graph);
return root_graph;
}
TEST_F(UTESTGraphPassesBasePass, while_infershape) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));
auto graph = BuildWhileGraph1();
auto ge_pass = GEPass(graph);
auto while_node = graph->FindNode("while");
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save