You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
6.0 KiB
161 lines
6.0 KiB
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
|
|
|
namespace paddle {
|
|
namespace inference {
|
|
namespace analysis {
|
|
|
|
// Name of the boolean node attribute set by MarkNodesInsideSubGraph() on every
// node that should be placed inside a sub-graph.
const char *SubGraphSplitter::kMarkerAttrName{
    "_sub_graph_splitter_inside_sub_graph"};
|
|
|
|
// Runs the splitter: first tags the candidate nodes, then groups the tagged
// nodes into sub-graphs and returns them.
std::vector<std::vector<Node *>> SubGraphSplitter::operator()() {
  MarkNodesInsideSubGraph();
  auto subgraphs = ExtractSubGraphs();
  return subgraphs;
}
|
|
|
|
// Mark the output variables inside a subgraph with the func.
|
|
inline void MarkOutLinksInSubGraph(const Function *func) {
|
|
for (auto *var : func->outlinks) {
|
|
var->attr(SubGraphSplitter::kMarkerAttrName).Bool() = true;
|
|
}
|
|
}
|
|
|
|
void SubGraphSplitter::MarkNodesInsideSubGraph() {
|
|
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) {
|
|
if (node_inside_subgraph_teller_(&node)) {
|
|
node.attr(kMarkerAttrName).Bool() = true;
|
|
if (node.type() == Node::Type::kFunction) {
|
|
// If a function is inside the sub-graph, mark all the output variables
|
|
// to be inside too, so that two marked functions will be inside a same
|
|
// sub-graph, lets take a example: A_function->var->B_function, if
|
|
// A_function is marked, var should also be marked, so that B_function
|
|
// will be in the same sub-graph with A_function if B_function is
|
|
// marked.
|
|
MarkOutLinksInSubGraph(static_cast<const Function *>(&node));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Name of the int32 node attribute holding the node's Union-Find parent id.
const char *kUnionFindParent{"_sub_graph_splitter_union_find_parent_"};
|
|
|
|
// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node
|
|
// a's output is node b, that is a and b is in the same sub-graph. The UF
|
|
// algorithm will group them to the same cluster.
|
|
using node_map_t = std::unordered_map<int, Node *>;
|
|
// Find the ancestor id of a node.
|
|
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
|
|
int tmp = id;
|
|
do {
|
|
tmp = node_map.at(tmp)->attr(kUnionFindParent).Int32();
|
|
} while (node_map.at(tmp)->attr(kUnionFindParent).Int32() != tmp);
|
|
return tmp;
|
|
}
|
|
// Make this two node share the same ancestor.
|
|
// TODO(Superjom) bad performance, make a balanced tree latter.
|
|
void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
|
|
int a_ancestor = UnionFindGetAncestor(node_map, a);
|
|
int b_ancestor = UnionFindGetAncestor(node_map, b);
|
|
node_map.at(b_ancestor)->attr(kUnionFindParent).Int32() = a_ancestor;
|
|
node_map.at(a)->attr(kUnionFindParent).Int32() = a_ancestor;
|
|
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
|
|
}
|
|
|
|
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
|
|
std::vector<Node *> marked_nodes;
|
|
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes_in_TS()) {
|
|
if (node.attr(kMarkerAttrName).Bool()) {
|
|
marked_nodes.push_back(&node);
|
|
}
|
|
}
|
|
// extract sub-graphs in the marked node set, use Union Find algorithm.
|
|
node_map_t node_map; // id to ptr
|
|
for (auto *n : marked_nodes) {
|
|
// n's parent == n.id means it is the ancestor
|
|
n->attr(kUnionFindParent).Int32() = n->id();
|
|
node_map[n->id()] = n;
|
|
}
|
|
std::unordered_set<Node *> visited;
|
|
for (auto *n : marked_nodes) {
|
|
for (auto *out : n->outlinks) {
|
|
if (node_map.count(out->id())) {
|
|
UnionFindCombine(node_map, n->id(), out->id());
|
|
}
|
|
}
|
|
}
|
|
|
|
std::unordered_map<int /*ancestor*/, std::vector<Node *>> clusters;
|
|
for (auto *n : marked_nodes) {
|
|
if (n->type() == Node::Type::kFunction) {
|
|
clusters[UnionFindGetAncestor(node_map,
|
|
n->attr(kUnionFindParent).Int32())]
|
|
.push_back(n);
|
|
}
|
|
}
|
|
std::vector<std::vector<Node *>> result;
|
|
std::for_each(clusters.begin(), clusters.end(),
|
|
[&](const decltype(clusters)::value_type &it) {
|
|
result.push_back(it.second);
|
|
});
|
|
|
|
return result;
|
|
}
|
|
|
|
// Runs the fuse pass: every detected sub-graph is collapsed into one block node.
void SubGraphFuse::operator()() {
  ReplaceNodesWithSubGraphs();
}
|
|
|
|
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
|
|
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
|
|
for (auto &subgraph : subgraphs) {
|
|
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
|
|
// replace this sub-graph with the first node. Two steps: 1. Create a Block
|
|
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph
|
|
// as deleted. 3. Replace the deleted node with the new Block Node.
|
|
auto *block_node = static_cast<FunctionBlock *>(
|
|
graph_->nodes.Create(Node::Type::kFunctionBlock));
|
|
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
|
|
block_node->inlinks = std::move(io.first);
|
|
block_node->outlinks = std::move(io.second);
|
|
for (auto *node : subgraph) {
|
|
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
|
|
// pass.
|
|
node->SetDeleted();
|
|
block_node->subgraph.push_back(node);
|
|
}
|
|
|
|
// Change all the sub-graph's inputs and outputs corresponding inlink and
|
|
// outlink to this sub-graph node.
|
|
auto inlink_or_outlink_cleaner = [&](std::vector<Node *> &nodes) {
|
|
for (auto *&n : nodes) {
|
|
if (subgraph_uniq.count(n)) {
|
|
n = block_node;
|
|
}
|
|
}
|
|
std::unordered_set<Node *> uniq(nodes.begin(), nodes.end());
|
|
nodes.assign(uniq.begin(), uniq.end());
|
|
};
|
|
for (auto *i : block_node->inlinks) {
|
|
inlink_or_outlink_cleaner(i->outlinks);
|
|
}
|
|
for (auto *&o : block_node->outlinks) {
|
|
inlink_or_outlink_cleaner(o->inlinks);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace analysis
|
|
} // namespace inference
|
|
} // namespace paddle
|