commit
f55e8901c8
@ -0,0 +1,187 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
|
||||
#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
|
||||
nodes_.emplace_back(new PDNode(std::move(teller), name));
|
||||
auto* cur = nodes_.back().get();
|
||||
return cur;
|
||||
}
|
||||
|
||||
void PDPattern::AddEdge(PDNode* a, PDNode* b) {
|
||||
PADDLE_ENFORCE(a);
|
||||
PADDLE_ENFORCE(b);
|
||||
PADDLE_ENFORCE(a != b, "can't connect to the same nodes.");
|
||||
edges_.emplace_back(a, b);
|
||||
}
|
||||
|
||||
void GraphPatternDetecter::operator()(Graph* graph,
|
||||
GraphPatternDetecter::handle_t handler) {
|
||||
if (!MarkPDNodesInGraph(*graph)) return;
|
||||
auto subgraphs = DetectPatterns();
|
||||
UniquePatterns(&subgraphs);
|
||||
RemoveOverlappedMatch(&subgraphs);
|
||||
|
||||
for (auto& g : subgraphs) {
|
||||
handler(g, graph);
|
||||
}
|
||||
}
|
||||
|
||||
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
|
||||
if (graph.Nodes().empty()) return false;
|
||||
|
||||
for (auto& node : GraphTraits::DFS(graph)) {
|
||||
for (const auto& pdnode : pattern_.nodes()) {
|
||||
if (pdnode->Tell(&node)) {
|
||||
pdnodes2nodes_[pdnode.get()].insert(&node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return !pdnodes2nodes_.empty();
|
||||
}
|
||||
|
||||
struct HitGroup {
|
||||
std::unordered_map<PDNode*, Node*> roles;
|
||||
|
||||
bool Match(Node* node, PDNode* pat) {
|
||||
return !roles.count(pat) || roles.at(pat) == node;
|
||||
}
|
||||
|
||||
void Register(Node* node, PDNode* pat) { roles[pat] = node; }
|
||||
};
|
||||
|
||||
// Tell whether Node a links to b.
|
||||
bool IsNodesLink(Node* a, Node* b) {
|
||||
for (auto* node : a->outputs) {
|
||||
if (b == node) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<GraphPatternDetecter::subgraph_t>
|
||||
GraphPatternDetecter::DetectPatterns() {
|
||||
// Init empty subgraphs.
|
||||
std::vector<GraphPatternDetecter::subgraph_t> result;
|
||||
std::vector<HitGroup> init_groups;
|
||||
PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
|
||||
auto* first_pnode = pattern_.edges().front().first;
|
||||
if (!pdnodes2nodes_.count(first_pnode)) return result;
|
||||
for (auto* node : pdnodes2nodes_[first_pnode]) {
|
||||
HitGroup group;
|
||||
group.roles[first_pnode] = node;
|
||||
init_groups.emplace_back(group);
|
||||
}
|
||||
|
||||
int step = 0;
|
||||
std::array<std::vector<HitGroup>, 2> bi_records;
|
||||
bi_records[0] = std::move(init_groups);
|
||||
|
||||
// Extend a PDNode to subgraphs by deducing the connection relations defined
|
||||
// in edges of PDNodes.
|
||||
for (const auto& edge : pattern_.edges()) {
|
||||
// Each role has two PDNodes, which indicates two roles.
|
||||
// Detect two Nodes that can match these two roles and they are connected.
|
||||
auto& pre_groups = bi_records[step % 2];
|
||||
auto& cur_groups = bi_records[1 - (step++ % 2)];
|
||||
cur_groups.clear();
|
||||
// source -> target
|
||||
for (Node* source : pdnodes2nodes_[edge.first]) {
|
||||
for (Node* target : pdnodes2nodes_[edge.second]) {
|
||||
// TODO(Superjomn) add some prune strategies.
|
||||
for (const auto& group : pre_groups) {
|
||||
HitGroup new_group = group;
|
||||
if (IsNodesLink(source, target) &&
|
||||
new_group.Match(source, edge.first)) {
|
||||
new_group.Register(source, edge.first);
|
||||
if (new_group.Match(target, edge.second)) {
|
||||
new_group.Register(target, edge.second);
|
||||
cur_groups.push_back(new_group);
|
||||
// TODO(Superjomn) need to unique
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& group : bi_records[step % 2]) {
|
||||
GraphPatternDetecter::subgraph_t subgraph;
|
||||
for (auto& role : group.roles) {
|
||||
subgraph.emplace(role.first, role.second);
|
||||
}
|
||||
result.emplace_back(subgraph);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void GraphPatternDetecter::UniquePatterns(
|
||||
std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) {
|
||||
if (subgraphs->empty()) return;
|
||||
std::vector<GraphPatternDetecter::subgraph_t> result;
|
||||
|
||||
std::unordered_set<size_t> set;
|
||||
for (auto& g : *subgraphs) {
|
||||
size_t key = 0;
|
||||
for (auto& item : g) {
|
||||
key ^= std::hash<void*>{}(item.first);
|
||||
key ^= std::hash<void*>{}(item.second);
|
||||
}
|
||||
if (!set.count(key)) {
|
||||
result.emplace_back(g);
|
||||
set.insert(key);
|
||||
}
|
||||
}
|
||||
*subgraphs = result;
|
||||
}
|
||||
|
||||
void GraphPatternDetecter::RemoveOverlappedMatch(
|
||||
std::vector<subgraph_t>* subgraphs) {
|
||||
std::vector<subgraph_t> result;
|
||||
std::unordered_set<Node*> node_set;
|
||||
|
||||
for (const auto& subgraph : *subgraphs) {
|
||||
bool valid = true;
|
||||
for (auto& item : subgraph) {
|
||||
if (node_set.count(item.second)) {
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (valid) {
|
||||
for (auto& item : subgraph) {
|
||||
node_set.insert(item.second);
|
||||
}
|
||||
result.push_back(subgraph);
|
||||
}
|
||||
}
|
||||
*subgraphs = result;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,181 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef PADDLE_WITH_TESTING
|
||||
#include <gtest/gtest_prod.h>
|
||||
#endif
|
||||
|
||||
#include <numeric>
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/node.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Some basic torminolygies:
|
||||
// - PDPattern: a pattern defined as a data flow graph.
|
||||
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
|
||||
// that meets some conditions defined in `PDNode.teller`.
|
||||
// - A pattern is defined with PDNodes with edges.
|
||||
|
||||
// Pattern detector node. This node helps to build a pattern.
|
||||
struct PDNode {
|
||||
// tell whether an ir::Node* is a candidation for a PDNode.
|
||||
using teller_t = std::function<bool(Node*)>;
|
||||
|
||||
PDNode(teller_t&& teller, const std::string& name = "")
|
||||
: teller_(teller), name_(name) {
|
||||
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
|
||||
}
|
||||
|
||||
PDNode(PDNode&& other) = default;
|
||||
|
||||
std::vector<PDNode*> inlinks;
|
||||
std::vector<PDNode*> outlinks;
|
||||
|
||||
bool Tell(Node* node) const {
|
||||
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
|
||||
return teller_(node);
|
||||
}
|
||||
|
||||
const std::string& name() const { return name_; }
|
||||
|
||||
PDNode(const PDNode&) = delete;
|
||||
PDNode& operator=(const PDNode&) = delete;
|
||||
|
||||
private:
|
||||
teller_t teller_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
/*
|
||||
* A pattern in a graph, which defined with PDNode and edges. Most graph
|
||||
* patterns can be divided into PDNodes and link relations between them.
|
||||
*
|
||||
* For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD
|
||||
* operators from the computation graph, the MUL's output should have only one
|
||||
* consumer which is the ELEMENTWISE_ADD.
|
||||
* This pattern can be defined as with the following pseudo codes
|
||||
*
|
||||
* // Create two operator PDNodes.
|
||||
* MUL = PDPattern.NewNode()
|
||||
* ELE = PDPattern.NewNode()
|
||||
* // Create the variable PDNodes.
|
||||
* MUL_out = PDPattern.NewNode()
|
||||
* // Add teller to define some rules that help to filter the target Nodes.
|
||||
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
|
||||
* ELE.teller = lambda(node): \
|
||||
* node->IsOp() && node->Op()->Type == "elementwise_add";
|
||||
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
|
||||
* && (ELE in node->outputs)
|
||||
*
|
||||
* One can add more specific tellers for PDNodes or edges, both the Operator
|
||||
* and Variable Nodes can be ruled in PDNode.teller.
|
||||
*
|
||||
* PDPattern can record the general patterns, such as the pattern represents
|
||||
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
|
||||
* - Ops whose inputs and outputs share the same variables
|
||||
*/
|
||||
class PDPattern {
|
||||
public:
|
||||
using edge_t = std::pair<PDNode*, PDNode*>;
|
||||
|
||||
void AddEdge(PDNode* a, PDNode* b);
|
||||
|
||||
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = "");
|
||||
|
||||
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
|
||||
const std::vector<edge_t>& edges() const { return edges_; }
|
||||
|
||||
private:
|
||||
#ifdef PADDLE_WITH_TESTING
|
||||
FRIEND_TEST(PDPattern, AddEdge);
|
||||
FRIEND_TEST(PDPattern, NewNode);
|
||||
#endif
|
||||
|
||||
std::vector<std::unique_ptr<PDNode>> nodes_;
|
||||
std::vector<edge_t> edges_;
|
||||
};
|
||||
|
||||
/*
|
||||
* GraphPatternDetecter helps to detect the specific patterns in the graph.
|
||||
* Input a pattern, output a list of the matched subgraphs/nodes.
|
||||
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
|
||||
*
|
||||
* The algorithm has three phases:
|
||||
* 1. Mark the nodes that match the defined PDNodes in a PDPattern,
|
||||
* 2. Extend a PDNode to subgraphs by deducing the connection relation defined
|
||||
* in PAPattern(the edges),
|
||||
* 3. Get the filtered subgraphs and treat them with a pre-defined handler.
|
||||
*
|
||||
* Usage:
|
||||
* // Create a detector
|
||||
* GraphPatternDetecter detector;
|
||||
* // Define the detector's pattern, by adding PDNode and define the edges.
|
||||
* auto* node0 = detector.mutable_pattern().AddNode(...)
|
||||
* auto* node1 = detector.mutable_pattern().AddNode(...)
|
||||
* node0->teller = some lambda.
|
||||
* node1->teller = some lambda.
|
||||
* detector.mutable_pattern().AddEdge(node0, node1);
|
||||
* // Create an handler, to define the behavior of treating the filtered
|
||||
* // subgraphs that comply with the patterns.
|
||||
* GraphPatternDetecter::handle_t handler = some labmda
|
||||
* // Execute the detector.
|
||||
* detector(&graph, handler);
|
||||
*/
|
||||
class GraphPatternDetecter {
|
||||
public:
|
||||
using subgraph_t = std::unordered_map<PDNode*, Node*>;
|
||||
|
||||
// Operate on the detected pattern.
|
||||
using handle_t =
|
||||
std::function<void(const subgraph_t& /*hitted pattern*/, Graph*)>;
|
||||
|
||||
void operator()(Graph* graph, handle_t handler);
|
||||
|
||||
const PDPattern& pattern() const { return pattern_; }
|
||||
PDPattern* mutable_pattern() { return &pattern_; }
|
||||
|
||||
private:
|
||||
// Mark the nodes that fits the pattern.
|
||||
bool MarkPDNodesInGraph(const ir::Graph& graph);
|
||||
|
||||
// Detect all the pattern and output the hit records.
|
||||
std::vector<subgraph_t> DetectPatterns();
|
||||
|
||||
// Remove duplicate patterns.
|
||||
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
|
||||
|
||||
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
|
||||
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
|
||||
|
||||
#ifdef PADDLE_WITH_TESTING
|
||||
FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph);
|
||||
FRIEND_TEST(GraphPatternDetecter, DetectPatterns);
|
||||
#endif
|
||||
|
||||
private:
|
||||
using hit_rcd_t =
|
||||
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
|
||||
PDPattern pattern_;
|
||||
std::vector<hit_rcd_t> marked_records_;
|
||||
std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,172 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
void BuildGraph(Graph* g) {
|
||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
||||
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
|
||||
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
|
||||
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
|
||||
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
|
||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
||||
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
|
||||
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
|
||||
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
|
||||
|
||||
// o1->v1->o2
|
||||
o1->outputs.push_back(v1);
|
||||
o2->inputs.push_back(v1);
|
||||
v1->inputs.push_back(o1);
|
||||
v1->outputs.push_back(o2);
|
||||
// o2->v2->o3
|
||||
// o2->v2->o4
|
||||
o2->outputs.push_back(v2);
|
||||
o3->inputs.push_back(v2);
|
||||
o4->inputs.push_back(v2);
|
||||
v2->inputs.push_back(o2);
|
||||
v2->outputs.push_back(o3);
|
||||
v2->outputs.push_back(o4);
|
||||
// o2->v3->o5
|
||||
o2->outputs.push_back(v3);
|
||||
o5->inputs.push_back(v3);
|
||||
v3->inputs.push_back(o2);
|
||||
v3->outputs.push_back(o5);
|
||||
// o3-v4->o5
|
||||
o3->outputs.push_back(v4);
|
||||
o5->inputs.push_back(v4);
|
||||
v4->inputs.push_back(o3);
|
||||
v4->outputs.push_back(o5);
|
||||
}
|
||||
|
||||
TEST(PDPattern, NewNode) {
|
||||
PDPattern x;
|
||||
auto* n = x.NewNode([](Node* x) { return true; });
|
||||
ASSERT_TRUE(n);
|
||||
ASSERT_EQ(x.nodes_.size(), 1UL);
|
||||
}
|
||||
|
||||
TEST(PDPattern, AddEdge) {
|
||||
PDPattern x;
|
||||
auto* a = x.NewNode([](Node* x) { return true; });
|
||||
auto* b = x.NewNode([](Node* x) { return true; });
|
||||
ASSERT_TRUE(a);
|
||||
ASSERT_TRUE(b);
|
||||
x.AddEdge(a, b);
|
||||
ASSERT_EQ(x.nodes_.size(), 2UL);
|
||||
ASSERT_EQ(x.edges_.size(), 1UL);
|
||||
ASSERT_EQ(x.edges_.front().first, a);
|
||||
ASSERT_EQ(x.edges_.front().second, b);
|
||||
|
||||
ASSERT_EQ(x.nodes().size(), 2UL);
|
||||
ASSERT_EQ(x.edges().size(), 1UL);
|
||||
ASSERT_EQ(x.edges().front().first, a);
|
||||
ASSERT_EQ(x.edges().front().second, b);
|
||||
}
|
||||
|
||||
TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
|
||||
GraphPatternDetecter x;
|
||||
// mark o2, o3, v2
|
||||
|
||||
// The pattern is a graph:
|
||||
// o2(a node named o2) -> v2(a node named v2)
|
||||
// v2 -> o3(a node named o3)
|
||||
auto* o2 = x.pattern_.NewNode([](Node* node) {
|
||||
// The teller can be any condition, such as op type, or variable's shape.
|
||||
return node && node->Name() == "op2" && node->IsOp();
|
||||
});
|
||||
auto* o3 = x.pattern_.NewNode([](Node* node) {
|
||||
// The teller can be any condition, such as op type, or variable's shape.
|
||||
return node && node->Name() == "op3" && node->IsOp();
|
||||
});
|
||||
auto* v2 = x.pattern_.NewNode([](Node* node) {
|
||||
// The teller can be any condition, such as op type, or variable's shape.
|
||||
return node && node->Name() == "var2" && node->IsVar();
|
||||
});
|
||||
|
||||
ASSERT_FALSE(o2->Tell(nullptr));
|
||||
ASSERT_FALSE(o3->Tell(nullptr));
|
||||
ASSERT_FALSE(v2->Tell(nullptr));
|
||||
|
||||
x.pattern_.AddEdge(o2, v2);
|
||||
x.pattern_.AddEdge(v2, o3);
|
||||
|
||||
ASSERT_EQ(x.pattern_.edges().size(), 2UL);
|
||||
ASSERT_EQ(x.pattern_.edges()[0].first, o2);
|
||||
ASSERT_EQ(x.pattern_.edges()[0].second, v2);
|
||||
ASSERT_EQ(x.pattern_.edges()[1].first, v2);
|
||||
ASSERT_EQ(x.pattern_.edges()[1].second, o3);
|
||||
|
||||
ProgramDesc program;
|
||||
Graph graph(program);
|
||||
BuildGraph(&graph);
|
||||
|
||||
x.MarkPDNodesInGraph(graph);
|
||||
|
||||
ASSERT_EQ(x.pdnodes2nodes_.size(), 3UL);
|
||||
|
||||
auto subgraphs = x.DetectPatterns();
|
||||
ASSERT_EQ(subgraphs.size(), 1UL);
|
||||
}
|
||||
|
||||
TEST(GraphPatternDetecter, MultiSubgraph) {
|
||||
ProgramDesc program;
|
||||
Graph graph(program);
|
||||
BuildGraph(&graph);
|
||||
|
||||
GraphPatternDetecter x;
|
||||
|
||||
// The pattern is a graph:
|
||||
// op -> var
|
||||
auto* any_op = x.mutable_pattern()->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3");
|
||||
},
|
||||
"OP0");
|
||||
auto* any_var = x.mutable_pattern()->NewNode(
|
||||
[](Node* node) { return node->IsVar(); }, "VAR");
|
||||
auto* any_op1 = x.mutable_pattern()->NewNode(
|
||||
[](Node* node) { return node->IsOp(); }, "OP1");
|
||||
|
||||
x.mutable_pattern()->AddEdge(any_op, any_var);
|
||||
x.mutable_pattern()->AddEdge(any_var, any_op1);
|
||||
|
||||
int count = 0;
|
||||
GraphPatternDetecter::handle_t handle = [&](
|
||||
const GraphPatternDetecter::subgraph_t& s, Graph* g) {
|
||||
LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
|
||||
<< s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
|
||||
count++;
|
||||
};
|
||||
|
||||
x(&graph, handle);
|
||||
|
||||
// 1. Detect op3 -> var4 -> op5
|
||||
// 2. Detect op2 -> var2 -> op3
|
||||
// 3. Detect op2 -> var2 -> op4
|
||||
// 4. Detect op2 -> var3 -> op5
|
||||
// But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2
|
||||
ASSERT_GE(count, 1UL);
|
||||
ASSERT_LE(count, 2UL);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
//
|
||||
// NodesDFSIterator
|
||||
//
|
||||
NodesDFSIterator::NodesDFSIterator(const std::vector<Node *> &source) {
|
||||
for (auto *x : source) stack_.push(x);
|
||||
}
|
||||
|
||||
NodesDFSIterator::NodesDFSIterator(NodesDFSIterator &&other) noexcept
|
||||
: stack_(std::move(other.stack_)),
|
||||
visited_(std::move(other.visited_)) {}
|
||||
|
||||
NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other)
|
||||
: stack_(other.stack_), visited_(other.visited_) {}
|
||||
|
||||
Node &NodesDFSIterator::operator*() {
|
||||
PADDLE_ENFORCE(!stack_.empty());
|
||||
return *stack_.top();
|
||||
}
|
||||
|
||||
NodesDFSIterator &NodesDFSIterator::operator++() {
|
||||
PADDLE_ENFORCE(!stack_.empty(), "the iterator exceeds range");
|
||||
visited_.insert(stack_.top());
|
||||
auto *cur = stack_.top();
|
||||
stack_.pop();
|
||||
for (auto *x : cur->outputs) {
|
||||
if (!visited_.count(x)) {
|
||||
stack_.push(x);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
bool NodesDFSIterator::operator==(const NodesDFSIterator &other) {
|
||||
if (stack_.empty()) return other.stack_.empty();
|
||||
if ((!stack_.empty()) && (!other.stack_.empty())) {
|
||||
return stack_.top() == other.stack_.top();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
NodesDFSIterator &NodesDFSIterator::operator=(const NodesDFSIterator &other) {
|
||||
stack_ = other.stack_;
|
||||
visited_ = other.visited_;
|
||||
return *this;
|
||||
}
|
||||
Node *NodesDFSIterator::operator->() { return stack_.top(); }
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stack>
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/node.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
template <typename IteratorT>
|
||||
class iterator_range {
|
||||
IteratorT begin_, end_;
|
||||
|
||||
public:
|
||||
template <typename Container>
|
||||
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
|
||||
|
||||
iterator_range(const IteratorT &begin, const IteratorT &end)
|
||||
: begin_(begin), end_(end) {}
|
||||
|
||||
const IteratorT &begin() const { return begin_; }
|
||||
const IteratorT &end() const { return end_; }
|
||||
};
|
||||
|
||||
// DFS iterator on nodes.
|
||||
struct NodesDFSIterator
|
||||
: public std::iterator<std::forward_iterator_tag, Node *> {
|
||||
NodesDFSIterator() = default;
|
||||
explicit NodesDFSIterator(const std::vector<Node *> &source);
|
||||
NodesDFSIterator(NodesDFSIterator &&other) noexcept;
|
||||
NodesDFSIterator(const NodesDFSIterator &other);
|
||||
|
||||
Node &operator*();
|
||||
NodesDFSIterator &operator++();
|
||||
// TODO(Superjomn) current implementation just compare the first
|
||||
// element, need to compare the graph and all the elements in the queue and
|
||||
// set.
|
||||
NodesDFSIterator &operator=(const NodesDFSIterator &other);
|
||||
bool operator==(const NodesDFSIterator &other);
|
||||
bool operator!=(const NodesDFSIterator &other) { return !(*this == other); }
|
||||
Node *operator->();
|
||||
|
||||
private:
|
||||
std::stack<Node *> stack_;
|
||||
std::unordered_set<Node *> visited_;
|
||||
};
|
||||
|
||||
/*
|
||||
* GraphTraits contains some graph traversal algorithms.
|
||||
*
|
||||
* Usage:
|
||||
*
|
||||
*/
|
||||
struct GraphTraits {
|
||||
static iterator_range<NodesDFSIterator> DFS(const Graph &g) {
|
||||
auto start_points = ExtractStartPoints(g);
|
||||
NodesDFSIterator x(start_points);
|
||||
return iterator_range<NodesDFSIterator>(NodesDFSIterator(start_points),
|
||||
NodesDFSIterator());
|
||||
}
|
||||
|
||||
private:
|
||||
// The nodes those have no input will be treated as start points.
|
||||
static std::vector<Node *> ExtractStartPoints(const Graph &g) {
|
||||
std::vector<Node *> result;
|
||||
for (auto *node : g.Nodes()) {
|
||||
if (node->inputs.empty()) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,46 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pthread.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
struct RWLock {
|
||||
RWLock() { pthread_rwlock_init(&lock_, nullptr); }
|
||||
|
||||
~RWLock() { pthread_rwlock_destroy(&lock_); }
|
||||
|
||||
void RDLock() {
|
||||
PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0,
|
||||
"acquire read lock failed");
|
||||
}
|
||||
|
||||
void WRLock() {
|
||||
PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0,
|
||||
"acquire write lock failed");
|
||||
}
|
||||
|
||||
void UNLock() {
|
||||
PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed");
|
||||
}
|
||||
|
||||
private:
|
||||
pthread_rwlock_t lock_;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue