|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
|
|
|
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
|
|
|
|
|
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_set<Node *> visited;
|
|
|
|
|
std::set<Node *> to_visit{source.begin(), source.end()};
|
|
|
|
|
|
|
|
|
|
std::vector<Node *> inlink_visited;
|
|
|
|
|
std::vector<Node *> inlink_sorted;
|
|
|
|
|
while (!to_visit.empty()) {
|
|
|
|
|
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
|
|
|
|
|
for (auto *p : queue) {
|
|
|
|
|
inlink_visited.clear();
|
|
|
|
|
|
|
|
|
|
std::copy_if(p->inputs.begin(), p->inputs.end(),
|
|
|
|
|
std::back_inserter(inlink_visited),
|
|
|
|
|
[&](Node *x) -> bool { return visited.count(x) != 0; });
|
|
|
|
|
|
|
|
|
|
if (inlink_visited.size() == p->inputs.size()) {
|
|
|
|
|
sorted_.push_back(p);
|
|
|
|
|
for (auto *_ : p->outputs) {
|
|
|
|
|
if (!visited.count(_)) {
|
|
|
|
|
to_visit.insert(_);
|
|
|
|
|
}
|
|
|
|
|
to_visit.erase(p);
|
|
|
|
|
sorted_.push_back(p);
|
|
|
|
|
for (auto *out : p->outputs) {
|
|
|
|
|
inlink_sorted.clear();
|
|
|
|
|
std::copy_if(out->inputs.begin(), out->inputs.end(),
|
|
|
|
|
std::back_inserter(inlink_sorted), [&](Node *x) -> bool {
|
|
|
|
|
return std::find(sorted_.begin(), sorted_.end(), x) !=
|
|
|
|
|
sorted_.end();
|
|
|
|
|
});
|
|
|
|
|
if (inlink_sorted.size() == out->inputs.size()) {
|
|
|
|
|
to_visit.insert(out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
to_visit.erase(p);
|
|
|
|
|
visited.insert(p);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|