|
|
|
@ -74,13 +74,134 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
|
|
|
|
|
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This is a simple representation of a graph.
|
|
|
|
|
// The BriefNode hold the pointer of the Node.
|
|
|
|
|
// This is to avoid changing the original graph
|
|
|
|
|
// in the process of trt graph analysis.
|
|
|
|
|
struct BriefNode {
|
|
|
|
|
explicit BriefNode(Node *n) { node = n; }
|
|
|
|
|
Node *node;
|
|
|
|
|
std::vector<BriefNode *> inlinks;
|
|
|
|
|
std::vector<BriefNode *> outlinks;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Union two adjacent BriefNode.
|
|
|
|
|
// Suppose we have two adjacent nodes src and dst.
|
|
|
|
|
// We will perform the following operations:
|
|
|
|
|
// 1. add all inputs(except src) of dst to src inlinks.
|
|
|
|
|
// 2. add all outputs of dst to src outlinks.
|
|
|
|
|
// 3. change all the dst's inputs and outputs
|
|
|
|
|
// corresponding inlinks and outlinks to src node.
|
|
|
|
|
// 4. delete all dst's inlinks and outlinks.
|
|
|
|
|
void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map,
|
|
|
|
|
int src_id, int dst_id) {
|
|
|
|
|
// merge the two adjacent nodes into one node.
|
|
|
|
|
BriefNode *src_node = node_map.at(src_id);
|
|
|
|
|
BriefNode *dst_node = node_map.at(dst_id);
|
|
|
|
|
|
|
|
|
|
std::unordered_set<BriefNode *> inputs(src_node->inlinks.begin(),
|
|
|
|
|
src_node->inlinks.end());
|
|
|
|
|
std::unordered_set<BriefNode *> outputs;
|
|
|
|
|
|
|
|
|
|
for (auto *n : src_node->outlinks) {
|
|
|
|
|
if (n != dst_node) outputs.insert(n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Add the inlinks and outlinks of dst node to src node.
|
|
|
|
|
std::vector<BriefNode *> dst_in_nodes = dst_node->inlinks;
|
|
|
|
|
for (BriefNode *node : dst_in_nodes) {
|
|
|
|
|
if (node != src_node) {
|
|
|
|
|
inputs.insert(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<BriefNode *> dst_out_nodes = dst_node->outlinks;
|
|
|
|
|
for (BriefNode *node : dst_out_nodes) {
|
|
|
|
|
outputs.insert(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// update the dst and src node's inlinks and outlinks.
|
|
|
|
|
src_node->inlinks =
|
|
|
|
|
std::move(std::vector<BriefNode *>(inputs.begin(), inputs.end()));
|
|
|
|
|
src_node->outlinks =
|
|
|
|
|
std::move(std::vector<BriefNode *>(outputs.begin(), outputs.end()));
|
|
|
|
|
dst_node->inlinks.clear();
|
|
|
|
|
dst_node->outlinks.clear();
|
|
|
|
|
|
|
|
|
|
auto inlink_or_outlink_cleaner = [&](std::vector<BriefNode *> &nodes) {
|
|
|
|
|
for (auto *&n : nodes) {
|
|
|
|
|
if (n == src_node || n == dst_node) {
|
|
|
|
|
n = src_node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// Change all the dst inputs and outputs corresponding inlink and
|
|
|
|
|
// outlink to the src node.
|
|
|
|
|
for (auto *node : src_node->inlinks) {
|
|
|
|
|
inlink_or_outlink_cleaner(node->outlinks);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto *node : src_node->outlinks) {
|
|
|
|
|
inlink_or_outlink_cleaner(node->inlinks);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// FlexibleDFS
|
|
|
|
|
// If reverse is true, do reverse dfs.
|
|
|
|
|
// If enter func is not nullptr, calls enter(node) before visiting any children
|
|
|
|
|
// of node.
|
|
|
|
|
// If leave func not nullptr, calls leave(node) after visiting all parents of
|
|
|
|
|
// node.
|
|
|
|
|
void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse,
|
|
|
|
|
const std::function<bool(const BriefNode *)> &enter,
|
|
|
|
|
const std::function<bool(const BriefNode *)> &leave) {
|
|
|
|
|
typedef struct {
|
|
|
|
|
const BriefNode *node;
|
|
|
|
|
bool leave;
|
|
|
|
|
} FNode;
|
|
|
|
|
|
|
|
|
|
std::vector<FNode> stack;
|
|
|
|
|
for (auto &node : source) {
|
|
|
|
|
stack.push_back(FNode{node, false});
|
|
|
|
|
}
|
|
|
|
|
std::unordered_set<const BriefNode *> visited;
|
|
|
|
|
while (!stack.empty()) {
|
|
|
|
|
auto fnode = stack.back();
|
|
|
|
|
stack.pop_back();
|
|
|
|
|
|
|
|
|
|
if (fnode.leave) {
|
|
|
|
|
if (leave && !leave(fnode.node)) return;
|
|
|
|
|
}
|
|
|
|
|
if (visited.count(fnode.node)) continue;
|
|
|
|
|
visited.insert(fnode.node);
|
|
|
|
|
|
|
|
|
|
if (enter && !enter(fnode.node)) return;
|
|
|
|
|
|
|
|
|
|
if (leave) stack.push_back(FNode{fnode.node, true});
|
|
|
|
|
const std::vector<BriefNode *> iter_nodes =
|
|
|
|
|
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
|
|
|
|
|
for (const BriefNode *node : iter_nodes) {
|
|
|
|
|
if (!visited.count(node)) {
|
|
|
|
|
stack.push_back(FNode{node, false});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
|
|
|
|
|
// Run the Extract algorithm to find all subgraphs.
|
|
|
|
|
std::vector<Node *> marked_nodes;
|
|
|
|
|
// We use brief_node_map to represent the original graph in order to avoid
|
|
|
|
|
// changing the original graph.
|
|
|
|
|
std::unordered_map<int, BriefNode *> brief_node_map;
|
|
|
|
|
|
|
|
|
|
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
|
|
|
|
|
brief_node_map[node.id()] = new BriefNode(&node);
|
|
|
|
|
if (node.attr(kMarkerAttrName).Bool()) {
|
|
|
|
|
marked_nodes.push_back(&node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// extract sub-graphs in the marked node set, use Union Find algorithm.
|
|
|
|
|
node_map_t node_map; // id to ptr
|
|
|
|
|
for (auto *n : marked_nodes) {
|
|
|
|
@ -88,11 +209,73 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
|
|
|
|
|
n->attr(kUnionFindParent).Int32() = n->id();
|
|
|
|
|
node_map[n->id()] = n;
|
|
|
|
|
}
|
|
|
|
|
std::unordered_set<Node *> visited;
|
|
|
|
|
for (auto *n : marked_nodes) {
|
|
|
|
|
for (auto *out : n->outlinks) {
|
|
|
|
|
if (node_map.count(out->id())) {
|
|
|
|
|
UnionFindCombine(node_map, n->id(), out->id());
|
|
|
|
|
|
|
|
|
|
// Create the brief node map.
|
|
|
|
|
for (auto &itr : brief_node_map) {
|
|
|
|
|
for (Node *node : itr.second->node->inlinks) {
|
|
|
|
|
itr.second->inlinks.push_back(brief_node_map[node->id()]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (Node *node : itr.second->node->outlinks) {
|
|
|
|
|
itr.second->outlinks.push_back(brief_node_map[node->id()]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &itr : brief_node_map) {
|
|
|
|
|
BriefNode *brief_node = itr.second;
|
|
|
|
|
|
|
|
|
|
if (!brief_node->node->attr(kMarkerAttrName).Bool()) {
|
|
|
|
|
VLOG(4) << brief_node->node->id() << " node not a trt candicate.";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Our algorithm must guarantee that:
|
|
|
|
|
// 1. The graph is always directed acyclic graph(DAG).
|
|
|
|
|
// 2. If there is a path in the subgraph from X to Y (X and Y are both
|
|
|
|
|
// nodes in the subgraph), then all paths from X to Y are in the
|
|
|
|
|
// subgraph.
|
|
|
|
|
//
|
|
|
|
|
// In order to achieve the above guarantee.
|
|
|
|
|
// For adjacent nodes src -> dst.
|
|
|
|
|
// 1. Get all dst input nodes except src.
|
|
|
|
|
// 2. Reverse DFS from those input nodes
|
|
|
|
|
// 3. If there is a path from input nodes to src,
|
|
|
|
|
// then the src and dst nodes can not be fused into one node,
|
|
|
|
|
// otherwise it can be done.
|
|
|
|
|
|
|
|
|
|
while (true) {
|
|
|
|
|
std::unordered_set<BriefNode *> contract_nodes;
|
|
|
|
|
for (auto *out : brief_node->outlinks) {
|
|
|
|
|
// must be a TRT candidate
|
|
|
|
|
if (!out->node->attr(kMarkerAttrName).Bool()) continue;
|
|
|
|
|
// get all dst input nodes except src.
|
|
|
|
|
std::vector<BriefNode *> source_nodes;
|
|
|
|
|
for (auto *n : out->inlinks) {
|
|
|
|
|
if (n != brief_node) {
|
|
|
|
|
source_nodes.push_back(n);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reverse DFS from the source_nodes.
|
|
|
|
|
bool have_excess_path = false;
|
|
|
|
|
FlexibleDFS(source_nodes, true, nullptr,
|
|
|
|
|
[&have_excess_path, brief_node](const BriefNode *n) {
|
|
|
|
|
if (n == brief_node) {
|
|
|
|
|
have_excess_path = true;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
if (have_excess_path) continue;
|
|
|
|
|
contract_nodes.insert(out);
|
|
|
|
|
}
|
|
|
|
|
if (contract_nodes.empty()) break;
|
|
|
|
|
|
|
|
|
|
for (auto dst_node : contract_nodes) {
|
|
|
|
|
UnionFindCombine(node_map, brief_node->node->id(),
|
|
|
|
|
dst_node->node->id());
|
|
|
|
|
UnionContractedNodes(brief_node_map, brief_node->node->id(),
|
|
|
|
|
dst_node->node->id());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -128,6 +311,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
|
|
|
|
|
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
|
|
|
|
|
block_node->inlinks = std::move(io.first);
|
|
|
|
|
block_node->outlinks = std::move(io.second);
|
|
|
|
|
|
|
|
|
|
for (auto *node : subgraph) {
|
|
|
|
|
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
|
|
|
|
|
// pass.
|
|
|
|
|