|
|
|
@ -160,6 +160,77 @@ TEST(DataFlowGraph, Build_IR_Graph) {
|
|
|
|
|
ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// FlexibleDFS
|
|
|
|
|
/*
|
|
|
|
|
* Graph topology
|
|
|
|
|
* inputs: 0
|
|
|
|
|
* 0 -> 1
|
|
|
|
|
* 1 -> 2
|
|
|
|
|
* 1 -> 3
|
|
|
|
|
* 3 -> 4
|
|
|
|
|
* 4 -> 5
|
|
|
|
|
* 5 -> 2
|
|
|
|
|
*/
|
|
|
|
|
TEST(DataFlowGraph, flexibledfs) {
|
|
|
|
|
DataFlowGraph graph;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 6; i++) {
|
|
|
|
|
auto* node = graph.nodes.Create(Node::Type::kValue);
|
|
|
|
|
node->SetName("node-" + std::to_string(i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto add_link = [&](int i, int j) {
|
|
|
|
|
Node* source = graph.nodes.GetMutable(i);
|
|
|
|
|
Node* target = graph.nodes.GetMutable(j);
|
|
|
|
|
target->inlinks.push_back(source);
|
|
|
|
|
source->outlinks.push_back(target);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
add_link(0, 1);
|
|
|
|
|
add_link(1, 2);
|
|
|
|
|
add_link(1, 3);
|
|
|
|
|
add_link(3, 4);
|
|
|
|
|
add_link(4, 5);
|
|
|
|
|
add_link(5, 2);
|
|
|
|
|
graph.Build();
|
|
|
|
|
|
|
|
|
|
std::vector<const Node*> order;
|
|
|
|
|
FlexibleDFS(graph.inputs(), false, nullptr, [&order](const Node* n) {
|
|
|
|
|
order.push_back(n);
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(order.size(), 6UL);
|
|
|
|
|
|
|
|
|
|
order.clear();
|
|
|
|
|
// reverse dfs
|
|
|
|
|
FlexibleDFS(graph.outputs(), true, nullptr, [&order](const Node* n) {
|
|
|
|
|
order.push_back(n);
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(order.size(), 6UL);
|
|
|
|
|
|
|
|
|
|
// If we delete
|
|
|
|
|
Node* last_node = graph.nodes.GetMutable(2);
|
|
|
|
|
Node* direct_node = graph.nodes.GetMutable(1);
|
|
|
|
|
std::vector<Node*> source_nodes;
|
|
|
|
|
for (Node* node : last_node->inlinks) {
|
|
|
|
|
if (node != direct_node) source_nodes.push_back(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool has_cycle = false;
|
|
|
|
|
FlexibleDFS(source_nodes, true, nullptr,
|
|
|
|
|
[&has_cycle, direct_node](const Node* n) {
|
|
|
|
|
if (n == direct_node) {
|
|
|
|
|
has_cycle = true;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
ASSERT_TRUE(has_cycle);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace analysis
|
|
|
|
|
} // namespace inference
|
|
|
|
|
} // namespace paddle
|
|
|
|
|