You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/ir/graph_helper.cc

373 lines
11 KiB

7 years ago
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
7 years ago
#include <algorithm>
#include <deque>
#include <fstream>
#include <iosfwd>
#include <ostream>
#include <stack>
#include <unordered_map>
7 years ago
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_traits.h"
7 years ago
DEFINE_string(print_sub_graph_dir, "",
"FLAGS_print_sub_graph_dir is used "
"to print the nodes of sub_graphs.");
7 years ago
namespace paddle {
namespace framework {
namespace ir {
namespace {
void SortHelper(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
ir::Node *node, std::unordered_set<ir::Node *> *visited,
std::vector<ir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
if (visited->find(adj) == visited->end()) {
SortHelper(adj_list, adj, visited, ret);
}
}
VLOG(5) << "topology sort insert: " << node->Name() << " "
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
7 years ago
ret->push_back(node);
}
bool HasCircleHelper(
ir::Node *node,
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace,
std::vector<std::vector<ir::Node *>> *circles) {
7 years ago
if (visited->find(node) == visited->end()) {
visited->insert(node);
in_trace->insert(node);
for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
7 years ago
return true;
} else if (in_trace->find(in) != in_trace->end()) {
if (circles != nullptr) {
std::vector<ir::Node *> circle;
circle.emplace_back(in);
ir::Node *p = in;
for (auto &adj : adj_list.at(p)) {
if (in_trace->count(adj)) {
circle.emplace_back(adj);
p = adj;
}
}
circles->emplace_back(circle);
}
7 years ago
return true;
}
}
}
in_trace->erase(node);
return false;
}
7 years ago
bool HasCircleInternal(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::vector<std::vector<ir::Node *>> *circles) {
7 years ago
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
7 years ago
return true;
}
}
return false;
}
7 years ago
} // namespace
bool HasCircle(const Graph &graph) {
return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
}
bool FindCircleSubGraph(const Graph &graph,
std::vector<std::vector<ir::Node *>> *circles) {
return HasCircleInternal(BuildOperationAdjList(graph), circles);
7 years ago
}
7 years ago
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
7 years ago
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
7 years ago
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &ret);
}
}
return ret;
}
// Build operator inlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
7 years ago
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
7 years ago
for (auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
7 years ago
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
7 years ago
}
std::vector<ir::Node *> nodes;
7 years ago
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
nodes.push_back(adj_n);
7 years ago
}
}
std::sort(nodes.begin(), nodes.end(), [](ir::Node *node1, ir::Node *node2) {
return node1->id() > node2->id();
});
adj_list[n].insert(std::make_move_iterator(nodes.begin()),
std::make_move_iterator(nodes.end()));
7 years ago
}
return adj_list;
}
// Build operator outlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
for (auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
}
for (auto &var : n->outputs) {
for (auto &adj_n : var->outputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
}
}
}
return adj_list;
}
std::vector<ir::Node *> OpDFSSort(const Graph &graph) {
auto edge_table = BuildOperationOutAdjList(graph);
std::stack<Node *> stack;
for (auto &ele : edge_table) {
if (ele.first->inputs.empty()) {
// find the input ops (those without input vars)
stack.push(ele.first);
} else {
// find the ops with only persistable vars as inputs.
bool all_persistable = true;
for (auto *input : ele.first->inputs) {
if (!(input->IsVar() && input->Var() && input->Var()->Persistable())) {
all_persistable = false;
}
}
if (all_persistable) {
stack.push(ele.first);
}
}
}
std::vector<Node *> res;
// start from the feed op and DFS
std::unordered_set<Node *> unique_set;
while (!stack.empty()) {
// will start from the last feed by default.
auto cur = stack.top();
stack.pop();
unique_set.insert(cur);
res.push_back(cur);
for (auto *op : edge_table[cur]) {
if (!unique_set.count(op)) {
stack.push(op);
}
}
}
return res;
}
std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph) {
std::vector<ir::Node *> nodes;
std::unordered_map<Node *, int> in_degree;
auto set_out_ops_ready = [&](Node *var) {
for (auto *op : var->outputs) {
--in_degree[op];
}
};
// build in_degree
for (auto *node : graph.Nodes()) {
if (node->IsOp()) {
in_degree[node] += node->inputs.size();
} else if (node->IsVar() && node->inputs.empty()) {
// put all the inputs of the whole graph ready.
set_out_ops_ready(node);
}
}
std::deque<Node *> op_queue;
// first visit
for (auto &node : OpDFSSort(graph)) {
if (node->IsOp()) {
op_queue.push_back(node);
}
}
// traverse the graph
int num_ops = op_queue.size();
while (num_ops) {
for (auto it = op_queue.begin(); it != op_queue.end(); it++) {
auto *&cur_op = *it;
if (!cur_op || in_degree[cur_op] > 0) continue;
// visit this node
// put all the output var of this op valid.
for (auto *out_var : cur_op->outputs) {
if (!out_var) continue;
set_out_ops_ready(out_var);
}
VLOG(8) << "visit " << cur_op->Name();
nodes.push_back(cur_op);
cur_op = nullptr;
num_ops--;
}
}
return nodes;
}
size_t GraphNum(const Graph &graph) {
std::unordered_set<ir::Node *> nodes(graph.Nodes());
std::unordered_set<ir::Node *> visited_nodes;
visited_nodes.reserve(nodes.size());
std::deque<ir::Node *> q_nodes;
std::vector<std::unordered_set<ir::Node *>> graph_nodes;
std::unordered_set<ir::Node *> g_nodes;
// q_set used to record records in the queue.
std::unordered_set<ir::Node *> q_set;
size_t graph_count = 0;
auto traverse_nodes = [&visited_nodes, &q_nodes,
&q_set](const std::vector<ir::Node *> &nodes) {
for (auto n : nodes) {
if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
q_nodes.push_back(n);
q_set.insert(n);
}
}
};
while (visited_nodes.size() != nodes.size()) {
if (!q_nodes.empty()) {
auto cur_node = q_nodes.front();
q_nodes.pop_front();
q_set.erase(cur_node);
visited_nodes.insert(cur_node);
g_nodes.insert(cur_node);
traverse_nodes(cur_node->inputs);
traverse_nodes(cur_node->outputs);
} else {
++graph_count;
if (g_nodes.size()) {
graph_nodes.emplace_back(g_nodes);
}
g_nodes.clear();
for (auto &n : nodes) {
if (visited_nodes.count(n) == 0) {
q_nodes.push_back(n);
q_set.insert(n);
break;
}
}
}
}
if (g_nodes.size()) {
graph_nodes.emplace_back(g_nodes);
}
if (FLAGS_print_sub_graph_dir.size()) {
if (graph_nodes.size() > 1) {
std::stringstream out;
for (auto &g_n : graph_nodes) {
out << "graph_nodes: " << g_n.size() << "\n";
}
out << "\n\n";
for (auto &g_n : graph_nodes) {
out << "graph_nodes: " << g_n.size();
for (auto &node : g_n) {
out << "\nNode: " << node->Name() << " in [";
for (auto &n : node->inputs) {
out << n->Name() << ", ";
}
out << "], out[";
for (auto &n : node->outputs) {
out << n->Name() << ", ";
}
out << "]";
}
out << "\n\n\n";
}
std::unique_ptr<std::ostream> fout(
new std::ofstream(FLAGS_print_sub_graph_dir));
PADDLE_ENFORCE(fout->good());
*fout << out.str();
}
}
return graph_count;
}
void CleanIndividualNodes(Graph *graph) {
std::unordered_set<Node *> nodes2rm;
for (auto *node : graph->Nodes()) {
if (node->inputs.empty() && node->outputs.empty()) {
nodes2rm.insert(node);
}
}
for (auto *node : nodes2rm) {
graph->RemoveNode(node);
}
}
std::vector<Node *> TopologyVarientSort(const Graph &graph,
SortKind sort_kind) {
switch (sort_kind) {
case SortKind::TS:
return framework::ir::TopologySortOperations(graph);
default:
return framework::ir::TopologyDfsSortOperations(graph);
}
}
7 years ago
} // namespace ir
} // namespace framework
} // namespace paddle