|
|
|
@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass {
|
|
|
|
|
names_to_add_repass_.erase(iter);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// simulate infershape pass
|
|
|
|
|
if(node->GetType() == WHILE){
|
|
|
|
|
bool need_repass = false;
|
|
|
|
|
AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass);
|
|
|
|
|
if(!OptionExists(kOptimizeAfterSubGraph)){
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
if(need_repass){
|
|
|
|
|
AddImmediateRePassNode(node);
|
|
|
|
|
}
|
|
|
|
|
else{
|
|
|
|
|
// clear attr on while
|
|
|
|
|
node->GetOpDesc()->DelAttr("need_infer_again_");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
void clear() { iter_nodes_.clear(); }
|
|
|
|
@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) {
|
|
|
|
|
EXPECT_EQ(test_pass.GetRunTimes(), 1007);
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
TEST_F(UTESTGraphPassesBasePass, while_loop) {
|
|
|
|
|
NamesToPass names_to_pass;
|
|
|
|
|
auto test_pass = UtestTestPass(true);
|
|
|
|
@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) {
|
|
|
|
|
auto ge_pass = GEPass(graph);
|
|
|
|
|
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// data1 const
|
|
|
|
|
/// \ /
|
|
|
|
|
/// while
|
|
|
|
|
/// / \
|
|
|
|
|
/// | |
|
|
|
|
|
/// cast1 cast2
|
|
|
|
|
ComputeGraphPtr BuildWhileGraph1() {
|
|
|
|
|
// build sub graph
|
|
|
|
|
auto builder_sub = ut::GraphBuilder("sub");
|
|
|
|
|
auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1);
|
|
|
|
|
auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1);
|
|
|
|
|
auto add = builder_sub.AddNode("add", ADD, 2, 1);
|
|
|
|
|
|
|
|
|
|
builder_sub.AddDataEdge(data_1, 0, add, 0);
|
|
|
|
|
builder_sub.AddDataEdge(data_2, 0, add, 1);
|
|
|
|
|
auto sub_graph = builder_sub.GetGraph();
|
|
|
|
|
sub_graph->SetName("while_sub");
|
|
|
|
|
// build root graph
|
|
|
|
|
auto builder = ut::GraphBuilder("g1");
|
|
|
|
|
auto data = builder.AddNode("data1", DATA, 0, 1);
|
|
|
|
|
auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1);
|
|
|
|
|
auto c1 = builder.AddNode("cast1", CAST, 1, 1);
|
|
|
|
|
auto c2 = builder.AddNode("cast2", CAST, 1, 1);
|
|
|
|
|
// add while op
|
|
|
|
|
auto tensor_desc = std::make_shared<GeTensorDesc>();
|
|
|
|
|
tensor_desc->SetShape(GeShape({1,1,1,1}));
|
|
|
|
|
tensor_desc->SetFormat(FORMAT_ND);
|
|
|
|
|
tensor_desc->SetDataType(DT_INT32);
|
|
|
|
|
|
|
|
|
|
auto op_desc = std::make_shared<OpDesc>("while", WHILE);
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
|
op_desc->AddInputDesc(tensor_desc->Clone());
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
|
op_desc->AddOutputDesc(tensor_desc->Clone());
|
|
|
|
|
}
|
|
|
|
|
AttrUtils::SetBool(op_desc,"need_infer_again_", true);
|
|
|
|
|
op_desc->AddSubgraphName(sub_graph->GetName());
|
|
|
|
|
op_desc->SetSubgraphInstanceName(0,sub_graph->GetName());
|
|
|
|
|
auto root_graph = builder.GetGraph();
|
|
|
|
|
auto while_op = root_graph->AddNode(op_desc);
|
|
|
|
|
|
|
|
|
|
builder.AddDataEdge(data, 0, while_op, 0);
|
|
|
|
|
builder.AddDataEdge(const_op, 0, while_op, 1);
|
|
|
|
|
builder.AddDataEdge(while_op, 0, c1, 0);
|
|
|
|
|
builder.AddDataEdge(while_op, 1, c2, 0);
|
|
|
|
|
sub_graph->SetParentGraph(root_graph);
|
|
|
|
|
sub_graph->SetParentNode(while_op);
|
|
|
|
|
root_graph->AddSubgraph(sub_graph);
|
|
|
|
|
return root_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(UTESTGraphPassesBasePass, while_infershape) {
|
|
|
|
|
NamesToPass names_to_pass;
|
|
|
|
|
auto test_pass = UtestTestPass();
|
|
|
|
|
names_to_pass.push_back(std::make_pair("test", &test_pass));
|
|
|
|
|
|
|
|
|
|
auto graph = BuildWhileGraph1();
|
|
|
|
|
auto ge_pass = GEPass(graph);
|
|
|
|
|
auto while_node = graph->FindNode("while");
|
|
|
|
|
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
|
|
|
|
|
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ge
|
|
|
|
|