|
|
|
@ -31,10 +31,10 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace {
|
|
|
|
|
void SortHelper(
|
|
|
|
|
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
|
|
|
|
|
ir::Node *node, std::unordered_set<ir::Node *> *visited,
|
|
|
|
|
std::vector<ir::Node *> *ret) {
|
|
|
|
|
void SortHelper(const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>,
|
|
|
|
|
ir::NodeComp> &adj_list,
|
|
|
|
|
ir::Node *node, std::unordered_set<ir::Node *> *visited,
|
|
|
|
|
std::vector<ir::Node *> *ret) {
|
|
|
|
|
visited->insert(node);
|
|
|
|
|
|
|
|
|
|
for (auto adj : adj_list.at(node)) {
|
|
|
|
@ -50,7 +50,8 @@ void SortHelper(
|
|
|
|
|
|
|
|
|
|
bool HasCircleHelper(
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
|
|
|
|
|
const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
|
|
|
|
|
&adj_list,
|
|
|
|
|
std::unordered_set<ir::Node *> *visited,
|
|
|
|
|
std::unordered_set<ir::Node *> *in_trace,
|
|
|
|
|
std::vector<std::vector<ir::Node *>> *circles) {
|
|
|
|
@ -84,7 +85,8 @@ bool HasCircleHelper(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasCircleInternal(
|
|
|
|
|
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
|
|
|
|
|
const std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
|
|
|
|
|
&adj_list,
|
|
|
|
|
std::vector<std::vector<ir::Node *>> *circles) {
|
|
|
|
|
std::unordered_set<ir::Node *> visited;
|
|
|
|
|
std::unordered_set<ir::Node *> in_trace;
|
|
|
|
@ -107,8 +109,8 @@ bool FindCircleSubGraph(const Graph &graph,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
|
|
|
|
|
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
|
|
|
|
|
BuildOperationAdjList(graph);
|
|
|
|
|
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
|
|
|
|
|
adj_list = BuildOperationAdjList(graph);
|
|
|
|
|
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
|
|
|
|
|
std::unordered_set<ir::Node *> visited;
|
|
|
|
|
std::vector<ir::Node *> ret;
|
|
|
|
@ -117,34 +119,30 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
|
|
|
|
|
SortHelper(adj_list, adj.first, &visited, &ret);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Build operator inlink edge table.
|
|
|
|
|
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
|
|
|
|
|
const Graph &graph) {
|
|
|
|
|
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
|
|
|
|
|
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
|
|
|
|
|
BuildOperationAdjList(const Graph &graph) {
|
|
|
|
|
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
|
|
|
|
|
adj_list;
|
|
|
|
|
|
|
|
|
|
for (auto &n : graph.Nodes()) {
|
|
|
|
|
if (!n->IsOp()) continue;
|
|
|
|
|
if (adj_list.find(n) == adj_list.end()) {
|
|
|
|
|
adj_list[n] = std::unordered_set<ir::Node *>();
|
|
|
|
|
adj_list[n] = std::set<ir::Node *, ir::NodeComp>();
|
|
|
|
|
}
|
|
|
|
|
std::vector<ir::Node *> nodes;
|
|
|
|
|
for (auto &var : n->inputs) {
|
|
|
|
|
for (auto &adj_n : var->inputs) {
|
|
|
|
|
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
|
|
|
|
|
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
|
|
|
|
|
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
|
|
|
|
|
<< " via " << var->Name() << reinterpret_cast<void *>(var);
|
|
|
|
|
nodes.push_back(adj_n);
|
|
|
|
|
adj_list[n].insert(adj_n);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::sort(nodes.begin(), nodes.end(), [](ir::Node *node1, ir::Node *node2) {
|
|
|
|
|
return node1->id() > node2->id();
|
|
|
|
|
});
|
|
|
|
|
adj_list[n].insert(std::make_move_iterator(nodes.begin()),
|
|
|
|
|
std::make_move_iterator(nodes.end()));
|
|
|
|
|
}
|
|
|
|
|
return adj_list;
|
|
|
|
|
}
|
|
|
|
|