Inference analysis/init data flow graph analysis (#10776)
Add the demo of subgraph splittershanyi15-patch-3
parent
a9f9fbadd9
commit
1153144fbb
@ -1,2 +1,17 @@
|
||||
cc_library(analysis SRCS dot.cc node.cc node.h)
|
||||
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
|
||||
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc
|
||||
DEPS paddle_fluid)
|
||||
cc_test(test_node SRCS node_tester.cc DEPS analysis)
|
||||
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
|
||||
|
||||
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
|
||||
|
||||
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid
|
||||
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
|
||||
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec)
|
||||
|
||||
cc_test(test_subgraph_splitter
|
||||
SRCS subgraph_splitter_tester.cc
|
||||
DEPS analysis paddle_fluid tensor
|
||||
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
|
||||
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec)
|
||||
|
@ -0,0 +1,205 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
#include "paddle/fluid/inference/analysis/dot.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
// It is a better idea that the inputs and outputs of this graph is set manully
|
||||
// before, but there must be a Pass that helps to prune the unnecessary ops that
|
||||
// do not contribute to the given targets, so in this pass, analysis and get the
|
||||
// inputs and outputs is OK.
|
||||
void DataFlowGraph::Build() {
|
||||
inputs.clear();
|
||||
outputs.clear();
|
||||
std::unordered_set<Node *> ins;
|
||||
std::unordered_set<Node *> outs;
|
||||
for (auto &node : nodes.nodes()) {
|
||||
for (auto *in : node->inlinks) {
|
||||
ins.insert(in);
|
||||
}
|
||||
for (auto *out : node->outlinks) {
|
||||
outs.insert(out);
|
||||
}
|
||||
}
|
||||
|
||||
// The nodes that in ins but not in outs is the graph's inputs
|
||||
// similarly, the nodes that in outs but not in ins is the graphs' outputs
|
||||
for (auto *in : ins) {
|
||||
if (!outs.count(in)) {
|
||||
inputs.push_back(in);
|
||||
}
|
||||
}
|
||||
for (auto *out : outs) {
|
||||
if (!outs.count(out)) {
|
||||
outputs.push_back(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string DataFlowGraph::DotString() const {
|
||||
Dot dot;
|
||||
|
||||
// Add nodes
|
||||
for (size_t i = 0; i < nodes.size(); i++) {
|
||||
const Node &node = nodes.Get(i);
|
||||
switch (node.type()) {
|
||||
case Node::Type::kValue:
|
||||
dot.AddNode(node.repr(), node.dot_attrs());
|
||||
break;
|
||||
case Node::Type::kFunction:
|
||||
dot.AddNode(node.repr(), node.dot_attrs());
|
||||
break;
|
||||
case Node::Type::kFunctionBlock:
|
||||
dot.AddNode(node.repr(), node.dot_attrs());
|
||||
break;
|
||||
default:
|
||||
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
|
||||
}
|
||||
}
|
||||
|
||||
// Add edges
|
||||
for (size_t i = 0; i < nodes.size(); i++) {
|
||||
const Node &node = nodes.Get(i);
|
||||
for (auto &in : node.inlinks) {
|
||||
dot.AddEdge(in->repr(), node.repr(), {});
|
||||
}
|
||||
}
|
||||
return dot.Build();
|
||||
}
|
||||
|
||||
//
|
||||
// NodesBFSIterator
|
||||
//
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
|
||||
const std::vector<Node *> &source)
|
||||
: queue_(source.begin(), source.end()) {}
|
||||
|
||||
// GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
|
||||
// GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept
|
||||
// : queue_(std::move(other.queue_)),
|
||||
// visited_(std::move(other.visited_)) {}
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
|
||||
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other)
|
||||
: queue_(other.queue_), visited_(other.visited_) {}
|
||||
|
||||
Node &GraphTraits<DataFlowGraph>::NodesBFSIterator::operator*() {
|
||||
PADDLE_ENFORCE(!queue_.empty());
|
||||
return *queue_.front();
|
||||
}
|
||||
|
||||
Node *GraphTraits<DataFlowGraph>::NodesBFSIterator::operator->() {
|
||||
PADDLE_ENFORCE(!queue_.empty());
|
||||
return queue_.front();
|
||||
}
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesBFSIterator &
|
||||
GraphTraits<DataFlowGraph>::NodesBFSIterator::operator=(
|
||||
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
|
||||
queue_ = other.queue_;
|
||||
visited_ = other.visited_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesBFSIterator
|
||||
&GraphTraits<DataFlowGraph>::NodesBFSIterator::operator++() {
|
||||
PADDLE_ENFORCE(!queue_.empty());
|
||||
auto *cur = queue_.front();
|
||||
visited_.insert(cur);
|
||||
queue_.pop_front();
|
||||
for (auto *output : cur->outlinks) {
|
||||
if (!visited_.count(output)) {
|
||||
queue_.push_back(output);
|
||||
visited_.insert(output);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
|
||||
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
|
||||
if (queue_.empty()) return other.queue_.empty();
|
||||
if ((!queue_.empty()) && (!other.queue_.empty())) {
|
||||
return queue_.front() == other.queue_.front() &&
|
||||
visited_.size() == other.visited_.size(); // here need to check the
|
||||
// equality of queue and
|
||||
// visited. Just a light but week implementation.
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// NodesDFSIterator
|
||||
//
|
||||
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
|
||||
const std::vector<Node *> &source) {
|
||||
for (auto *x : source) stack_.push(x);
|
||||
}
|
||||
|
||||
// GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
|
||||
// GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept
|
||||
// : stack_(std::move(other.stack_)),
|
||||
// visited_(std::move(other.visited_)) {}
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
|
||||
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other)
|
||||
: stack_(other.stack_), visited_(other.visited_) {}
|
||||
|
||||
Node &GraphTraits<DataFlowGraph>::NodesDFSIterator::operator*() {
|
||||
PADDLE_ENFORCE(!stack_.empty());
|
||||
return *stack_.top();
|
||||
}
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesDFSIterator
|
||||
&GraphTraits<DataFlowGraph>::NodesDFSIterator::operator++() {
|
||||
if (stack_.empty()) return *this;
|
||||
visited_.insert(stack_.top());
|
||||
auto *cur = stack_.top();
|
||||
stack_.pop();
|
||||
for (auto *x : cur->outlinks) {
|
||||
if (!visited_.count(x)) {
|
||||
stack_.push(x);
|
||||
visited_.insert(x);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
bool GraphTraits<DataFlowGraph>::NodesDFSIterator::operator==(
|
||||
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
|
||||
if (stack_.empty()) return other.stack_.empty();
|
||||
if ((!stack_.empty()) && (!other.stack_.empty())) {
|
||||
return stack_.top() == other.stack_.top();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
GraphTraits<DataFlowGraph>::NodesDFSIterator &
|
||||
GraphTraits<DataFlowGraph>::NodesDFSIterator::operator=(
|
||||
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
|
||||
stack_ = other.stack_;
|
||||
visited_ = other.visited_;
|
||||
return *this;
|
||||
}
|
||||
Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
|
||||
return stack_.top();
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,159 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
/*
|
||||
* Data flow graph is an pass that build the basic graph. It contains a graph
|
||||
* and the iterators that enable the iteration over the graph.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <deque>
|
||||
#include <stack>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "paddle/fluid/inference/analysis/graph_traits.h"
|
||||
#include "paddle/fluid/inference/analysis/node.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
/*
|
||||
* DataFlowGraph - A container of Value and Function Nodes.
|
||||
*/
|
||||
struct DataFlowGraph {
|
||||
NodeMap nodes;
|
||||
std::vector<Node *> inputs;
|
||||
std::vector<Node *> outputs;
|
||||
|
||||
// Extract inputs and outputs of the graph.
|
||||
void Build();
|
||||
|
||||
// Output a DOT graph file for debug.
|
||||
std::string DotString() const;
|
||||
};
|
||||
|
||||
/*
|
||||
* An graph trait help to traverse the graph using BFS.
|
||||
* The BFS start from a graph's inputs, the graph should be fully-connected, so
|
||||
* that the iterator can reach the end.
|
||||
*/
|
||||
template <>
|
||||
struct GraphTraits<DataFlowGraph> {
|
||||
// BFS iterator on nodes.
|
||||
struct NodesBFSIterator
|
||||
: public std::iterator<std::forward_iterator_tag, Node *> {
|
||||
NodesBFSIterator() = default;
|
||||
explicit NodesBFSIterator(const std::vector<Node *> &source);
|
||||
// NodesBFSIterator(NodesBFSIterator &&other) noexcept;
|
||||
// NOTE Heavy to use.
|
||||
NodesBFSIterator(const NodesBFSIterator &other);
|
||||
|
||||
Node &operator*();
|
||||
NodesBFSIterator &operator++();
|
||||
Node *operator->();
|
||||
// TODO(Superjomn) current implementation just compare the first
|
||||
// element, need to compare the graph and all the elements in the queue and
|
||||
// set.
|
||||
NodesBFSIterator &operator=(const NodesBFSIterator &other);
|
||||
bool operator==(const NodesBFSIterator &other);
|
||||
bool operator!=(const NodesBFSIterator &other) { return !(*this == other); }
|
||||
|
||||
private:
|
||||
std::deque<Node *> queue_;
|
||||
std::unordered_set<Node *> visited_;
|
||||
};
|
||||
|
||||
// DFS iterator on nodes.
|
||||
struct NodesDFSIterator
|
||||
: public std::iterator<std::forward_iterator_tag, Node *> {
|
||||
NodesDFSIterator() = default;
|
||||
explicit NodesDFSIterator(const std::vector<Node *> &source);
|
||||
// NodesDFSIterator(NodesDFSIterator &&other) noexcept;
|
||||
NodesDFSIterator(const NodesDFSIterator &other);
|
||||
|
||||
Node &operator*();
|
||||
NodesDFSIterator &operator++();
|
||||
// TODO(Superjomn) current implementation just compare the first
|
||||
// element, need to compare the graph and all the elements in the queue and
|
||||
// set.
|
||||
NodesDFSIterator &operator=(const NodesDFSIterator &other);
|
||||
bool operator==(const NodesDFSIterator &other);
|
||||
bool operator!=(const NodesDFSIterator &other) { return !(*this == other); }
|
||||
Node *operator->();
|
||||
|
||||
private:
|
||||
std::stack<Node *> stack_;
|
||||
std::unordered_set<Node *> visited_;
|
||||
};
|
||||
|
||||
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
|
||||
|
||||
// default use BFS to visit the nodes.
|
||||
iterator_range<NodesBFSIterator> nodes() {
|
||||
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
|
||||
}
|
||||
iterator_range<NodesBFSIterator> nodes_in_BFS() {
|
||||
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
|
||||
}
|
||||
iterator_range<NodesDFSIterator> nodes_in_DFS() {
|
||||
return iterator_range<NodesDFSIterator>(nodes_dfs_begin(), nodes_dfs_end());
|
||||
}
|
||||
|
||||
private:
|
||||
NodesBFSIterator nodes_bfs_begin() {
|
||||
return NodesBFSIterator(graph_->inputs);
|
||||
}
|
||||
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
|
||||
NodesDFSIterator nodes_dfs_begin() {
|
||||
return NodesDFSIterator(graph_->inputs);
|
||||
}
|
||||
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
|
||||
|
||||
private:
|
||||
DataFlowGraph *graph_;
|
||||
};
|
||||
|
||||
// Extract the inputs and outputs of a graph. The inputs and outputs of a
|
||||
// sub-graph is the inputs nodes and output nodes that doesn't inside the
|
||||
// sub-graph.
|
||||
std::pair<
|
||||
std::vector<Node *>,
|
||||
std::vector<
|
||||
Node *>> static ExtractInputAndOutputOfSubGraph(std::vector<Node *>
|
||||
&graph) {
|
||||
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
|
||||
std::unordered_set<Node *> inputs;
|
||||
std::unordered_set<Node *> outputs;
|
||||
for (auto &node : graph) {
|
||||
for (auto *in : node->inlinks) {
|
||||
if (!nodes.count(in) && in->type() == Node::Type::kValue) {
|
||||
inputs.insert(in);
|
||||
}
|
||||
}
|
||||
for (auto *out : node->outlinks) {
|
||||
if (!nodes.count(out) && out->type() == Node::Type::kValue) {
|
||||
outputs.insert(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
|
||||
std::vector<Node *>(outputs.begin(), outputs.end()));
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
TEST(DataFlowGraph, BFS) {
|
||||
auto desc = LoadProgramDesc();
|
||||
auto dfg = ProgramDescToDFG(desc);
|
||||
dfg.Build();
|
||||
|
||||
for (auto* in : dfg.inputs) {
|
||||
LOG(INFO) << "inputs: " << in->name() << " "
|
||||
<< static_cast<int>(in->type());
|
||||
}
|
||||
for (auto* out : dfg.outputs) {
|
||||
LOG(INFO) << "outputs: " << out->name() << " "
|
||||
<< static_cast<int>(out->type());
|
||||
}
|
||||
|
||||
GraphTraits<DataFlowGraph> trait(&dfg);
|
||||
auto nodes = trait.nodes();
|
||||
int count = 0;
|
||||
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
|
||||
LOG(INFO) << "visiting " << it->name();
|
||||
++count;
|
||||
}
|
||||
ASSERT_EQ(count, dfg.nodes.size());
|
||||
}
|
||||
|
||||
TEST(DataFlowGraph, DFS) {
|
||||
auto desc = LoadProgramDesc();
|
||||
auto dfg = ProgramDescToDFG(desc);
|
||||
dfg.Build();
|
||||
GraphTraits<DataFlowGraph> trait(&dfg);
|
||||
auto nodes = trait.nodes_in_DFS();
|
||||
int count = 0;
|
||||
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
|
||||
LOG(INFO) << "visiting " << it->name();
|
||||
++count;
|
||||
}
|
||||
ASSERT_EQ(count, dfg.nodes.size());
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,49 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/executor.h"
|
||||
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||
#include "paddle/fluid/inference/io.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
TEST_F(DFG_Tester, Test) {
|
||||
framework::proto::ProgramDesc new_desc;
|
||||
DataFlowGraph graph;
|
||||
|
||||
FluidToDataFlowGraphPass pass0;
|
||||
DataFlowGraphToFluidPass pass1;
|
||||
pass0.Initialize(desc);
|
||||
pass1.Initialize(&new_desc);
|
||||
|
||||
pass0.Run(&graph);
|
||||
pass1.Run(&graph);
|
||||
|
||||
pass0.Finalize();
|
||||
pass1.Finalize();
|
||||
|
||||
LOG(INFO) << graph.nodes.size();
|
||||
}
|
||||
|
||||
} // analysis
|
||||
} // inference
|
||||
} // paddle
|
@ -0,0 +1,83 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
FluidToDataFlowGraphPass::FluidToDataFlowGraphPass() {}
|
||||
|
||||
bool FluidToDataFlowGraphPass::Initialize() { return Pass::Initialize(); }
|
||||
|
||||
bool FluidToDataFlowGraphPass::Initialize(
|
||||
const framework::proto::ProgramDesc &desc) {
|
||||
desc_ = &desc;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); }
|
||||
|
||||
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
|
||||
// insert vars
|
||||
std::unordered_map<std::string, size_t> var2id;
|
||||
auto &main_block = desc_->blocks(framework::kRootBlockIndex);
|
||||
for (int i = 0; i < main_block.vars_size(); i++) {
|
||||
const auto &var = main_block.vars(i);
|
||||
auto *v = graph->nodes.Create(Node::Type::kValue);
|
||||
v->SetName(var.name());
|
||||
v->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&var)));
|
||||
var2id[var.name()] = v->id();
|
||||
}
|
||||
for (int i = 0; i < main_block.ops_size(); i++) {
|
||||
const auto &op = main_block.ops(i);
|
||||
auto *o = graph->nodes.Create(Node::Type::kFunction);
|
||||
o->SetName(op.type());
|
||||
static_cast<Function *>(o)->SetFuncType(op.type());
|
||||
// Link to the original protobuf message's memory, make it easier to
|
||||
// generate from a data flow graph to fluid ProgramDesc.
|
||||
o->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&op)));
|
||||
// set inputs and outputs
|
||||
// TODO(Superjomn) make sure the InputNames is the real variable name.
|
||||
for (int j = 0; j < op.inputs_size(); j++) {
|
||||
auto &in_var = op.inputs(j);
|
||||
for (int k = 0; k < in_var.arguments_size(); k++) {
|
||||
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
|
||||
in->outlinks.push_back(o);
|
||||
o->inlinks.push_back(in);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < op.outputs_size(); j++) {
|
||||
auto &out_var = op.outputs(j);
|
||||
for (int k = 0; k < out_var.arguments_size(); k++) {
|
||||
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
|
||||
out->inlinks.push_back(o);
|
||||
o->outlinks.push_back(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Analysis and extract the inputs and outputs of this graph.
|
||||
graph->Build();
|
||||
}
|
||||
|
||||
Pass *FluidToDataFlowGraphPass::CreatePrinterPass(
|
||||
std::ostream &os, const std::string &banner) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
/*
|
||||
* This file implements the transformation from data flow graph to fluid
|
||||
* ProgramDesc.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
#include "paddle/fluid/inference/analysis/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
/*
|
||||
* Transform a FluidDesc to a data flow graph.
|
||||
*/
|
||||
class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
|
||||
public:
|
||||
FluidToDataFlowGraphPass();
|
||||
bool Initialize() override;
|
||||
bool Initialize(const framework::proto::ProgramDesc &desc) override;
|
||||
bool Finalize() override;
|
||||
|
||||
void Run(DataFlowGraph *graph) override;
|
||||
|
||||
Pass *CreatePrinterPass(std::ostream &os,
|
||||
const std::string &banner) const override;
|
||||
|
||||
private:
|
||||
framework::proto::ProgramDesc const *desc_;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
TEST_F(DFG_Tester, Init) {
|
||||
FluidToDataFlowGraphPass pass;
|
||||
pass.Initialize();
|
||||
pass.Initialize(desc);
|
||||
DataFlowGraph graph;
|
||||
pass.Run(&graph);
|
||||
ASSERT_GT(graph.nodes.size(), 0);
|
||||
pass.Finalize();
|
||||
LOG(INFO) << '\n' << graph.DotString();
|
||||
}
|
||||
|
||||
} // analysis
|
||||
} // inference
|
||||
} // paddle
|
@ -0,0 +1,15 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/graph_traits.h"
|
@ -0,0 +1,63 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
/*
|
||||
* This file defines the GraphTraits<X> template class that should be specified
|
||||
* by classes that want to be iteratable by generic graph iterators.
|
||||
*
|
||||
* This file also defines the marker class Inverse that is used to iterate over
|
||||
* graphs in a graph defined, inverse ordering...
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/inference/analysis/helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
/*
|
||||
* This class should be specialized by different graph types...
|
||||
* That's why the base class is empty.
|
||||
*/
|
||||
template <typename GraphType>
|
||||
struct GraphTraits {
|
||||
// using NodesBFSIterator = xxx
|
||||
|
||||
// NodesBFSIterator nodes_begin();
|
||||
// NodesBFSIterator nodes_end();
|
||||
};
|
||||
|
||||
/*
|
||||
* Inverse - This class is used as a marker class to tell the graph iterator to
|
||||
* iterate in a graph defined Inverse order.
|
||||
*/
|
||||
template <typename GraphType>
|
||||
struct Inverse {
|
||||
const GraphType &graph;
|
||||
|
||||
explicit Inverse(const GraphType &graph) : graph(graph) {}
|
||||
};
|
||||
|
||||
/*
|
||||
* Provide a partial specialization of GraphTraits so that the inverse of an
|
||||
* inverse turns into the original graph.
|
||||
*/
|
||||
template <typename GraphType>
|
||||
struct GraphTraits<Inverse<Inverse<GraphType>>> : GraphTraits<GraphType> {};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -1,74 +1,107 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
template <typename IteratorT>
|
||||
class iterator_range {
|
||||
IteratorT begin_, end_;
|
||||
|
||||
public:
|
||||
template <typename Container>
|
||||
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
|
||||
|
||||
iterator_range(const IteratorT &begin, const IteratorT &end)
|
||||
: begin_(begin), end_(end) {}
|
||||
|
||||
const IteratorT &begin() const { return begin_; }
|
||||
const IteratorT &end() const { return end_; }
|
||||
};
|
||||
|
||||
/*
|
||||
* An registry helper class, with its records keeps the order they registers.
|
||||
*/
|
||||
template <typename T>
|
||||
class OrderedRegistry {
|
||||
public:
|
||||
T *Register(const std::string &name, T *x) {
|
||||
PADDLE_ENFORCE(!dic_.count(name));
|
||||
dic_[name] = data_.size();
|
||||
data_.emplace_back(std::unique_ptr<T>(x));
|
||||
return data_.back().get();
|
||||
}
|
||||
|
||||
T *Lookup(const std::string &name) {
|
||||
auto it = dic_.find(name);
|
||||
if (it == dic_.end()) return nullptr;
|
||||
return data_[it->second].get();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unordered_map<std::string, int> dic_;
|
||||
std::vector<std::unique_ptr<T>> data_;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
|
||||
\
|
||||
type__(const type__ &) = delete; \
|
||||
\
|
||||
void operator=(const type__ &) = delete;
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
|
||||
/*
|
||||
* Map typeid to representation.
|
||||
*/
|
||||
struct DataTypeNamer {
|
||||
static const DataTypeNamer &Global() {
|
||||
static auto *x = new DataTypeNamer();
|
||||
return *x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const std::string &repr() const {
|
||||
auto x = typeid(T).hash_code();
|
||||
PADDLE_ENFORCE(dic_.count(x), "unknown type for representation");
|
||||
return dic_.at(x);
|
||||
}
|
||||
|
||||
const std::string &repr(size_t &hash) const {
|
||||
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
|
||||
return dic_.at(hash);
|
||||
}
|
||||
|
||||
private:
|
||||
DataTypeNamer() {
|
||||
SET_TYPE(int);
|
||||
SET_TYPE(bool);
|
||||
SET_TYPE(float);
|
||||
}
|
||||
|
||||
std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_;
|
||||
};
|
||||
#undef SET_TYPE
|
||||
|
||||
template <typename IteratorT>
|
||||
class iterator_range {
|
||||
IteratorT begin_, end_;
|
||||
|
||||
public:
|
||||
template <typename Container>
|
||||
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
|
||||
|
||||
iterator_range(const IteratorT &begin, const IteratorT &end)
|
||||
: begin_(begin), end_(end) {}
|
||||
|
||||
const IteratorT &begin() const { return begin_; }
|
||||
const IteratorT &end() const { return end_; }
|
||||
};
|
||||
|
||||
/*
|
||||
* An registry helper class, with its records keeps the order they registers.
|
||||
*/
|
||||
template <typename T>
|
||||
class OrderedRegistry {
|
||||
public:
|
||||
T *Register(const std::string &name, T *x) {
|
||||
PADDLE_ENFORCE(!dic_.count(name));
|
||||
dic_[name] = data_.size();
|
||||
data_.emplace_back(std::unique_ptr<T>(x));
|
||||
return data_.back().get();
|
||||
}
|
||||
|
||||
T *Lookup(const std::string &name) {
|
||||
auto it = dic_.find(name);
|
||||
if (it == dic_.end()) return nullptr;
|
||||
return data_[it->second].get();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unordered_map<std::string, int> dic_;
|
||||
std::vector<std::unique_ptr<T>> data_;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
|
||||
\
|
||||
type__(const type__ &) = delete; \
|
||||
\
|
||||
void operator=(const type__ &) = delete;
|
||||
|
@ -0,0 +1,15 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/pass.h"
|
@ -0,0 +1,90 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include <iosfwd>
|
||||
|
||||
#include "paddle/fluid/framework/framework.pb.h"
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
#include "paddle/fluid/inference/analysis/helper.h"
|
||||
#include "paddle/fluid/inference/analysis/node.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
class Pass {
|
||||
public:
|
||||
Pass() = default;
|
||||
virtual ~Pass() {}
|
||||
// Virtual method overridden by subclasses to do only necessary initialization
|
||||
// before any pass is run.
|
||||
virtual bool Initialize() { return false; }
|
||||
// There is some passes such as FlowToDataFlowGraphPass that needs a
|
||||
// ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it
|
||||
// only couple with the proto file.
|
||||
virtual bool Initialize(const framework::proto::ProgramDesc &desc) {
|
||||
return false;
|
||||
}
|
||||
// There are some Passes such as DataFlowGraphToFluidPass that will output a
|
||||
// ProgramDesc.
|
||||
virtual bool Initialize(framework::proto::ProgramDesc *desc) { return false; }
|
||||
|
||||
// Virtual method overriden by subclasses to do any necessary clean up after
|
||||
// all passes have run.
|
||||
virtual bool Finalize() { return false; }
|
||||
|
||||
// Get a Pass appropriate to print the Node this pass operates on.
|
||||
virtual Pass *CreatePrinterPass(std::ostream &os,
|
||||
const std::string &banner) const = 0;
|
||||
|
||||
// Run on a single Node.
|
||||
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
|
||||
// Run on a single Function.
|
||||
virtual void Run(Function *x) { LOG(FATAL) << "not valid"; }
|
||||
// Run on a single FunctionBlock.
|
||||
virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; }
|
||||
// Run on a single DataFlowGraph.
|
||||
virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; }
|
||||
};
|
||||
|
||||
// NodePass process on any Node types.
|
||||
class NodePass : public Pass {
|
||||
public:
|
||||
virtual void Run(Node *node) = 0;
|
||||
};
|
||||
|
||||
// NodePass process on any Function node types.
|
||||
class FunctionPass : public Pass {
|
||||
public:
|
||||
virtual void Run(Function *node) = 0;
|
||||
};
|
||||
|
||||
// NodePass process on any FunctionBlock node types.
|
||||
class FunctionBlockPass : public Pass {
|
||||
public:
|
||||
virtual void Run(FunctionBlock *node) = 0;
|
||||
};
|
||||
|
||||
// GraphPass processes on any GraphType.
|
||||
class DataFlowGraphPass : public Pass {
|
||||
public:
|
||||
virtual void Run(DataFlowGraph *graph) = 0;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,154 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
const char *SubGraphSplitter::kMarkerAttrName =
|
||||
"_sub_graph_splitter_inside_sub_graph";
|
||||
|
||||
std::vector<std::vector<Node *>> SubGraphSplitter::operator()() {
|
||||
MarkNodesInsideSubGraph();
|
||||
return ExtractSubGraphs();
|
||||
}
|
||||
|
||||
// Mark the output variables inside a subgraph with the func.
|
||||
inline void MarkOutLinksInSubGraph(const Function *func) {
|
||||
for (auto *var : func->outlinks) {
|
||||
var->attr(SubGraphSplitter::kMarkerAttrName).Bool() = true;
|
||||
}
|
||||
}
|
||||
|
||||
void SubGraphSplitter::MarkNodesInsideSubGraph() {
|
||||
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) {
|
||||
if (node_inside_subgraph_teller_(&node)) {
|
||||
node.attr(kMarkerAttrName).Bool() = true;
|
||||
if (node.type() == Node::Type::kFunction) {
|
||||
// If a function is inside the sub-graph, mark all the output variables
|
||||
// to be inside too, so that two marked functions will be inside a same
|
||||
// sub-graph, lets take a example: A_function->var->B_function, if
|
||||
// A_function is marked, var should also be marked, so that B_function
|
||||
// will be in the same sub-graph with A_function if B_function is
|
||||
// marked.
|
||||
MarkOutLinksInSubGraph(static_cast<const Function *>(&node));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const char *kUnionFindParent = "_sub_graph_splitter_union_find_parent_";
|
||||
|
||||
// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node
|
||||
// a's output is node b, that is a and b is in the same sub-graph. The UF
|
||||
// algorithm will group them to the same cluster.
|
||||
using node_map_t = std::unordered_map<int, Node *>;
|
||||
// Find the ancestor id of a node.
|
||||
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
|
||||
int tmp = id;
|
||||
do {
|
||||
tmp = node_map.at(tmp)->attr(kUnionFindParent).Int32();
|
||||
} while (node_map.at(tmp)->attr(kUnionFindParent).Int32() != tmp);
|
||||
return tmp;
|
||||
}
|
||||
// Make this two node share the same ancestor.
|
||||
// TODO(Superjom) bad performance, make a balanced tree latter.
|
||||
void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
|
||||
int a_ancestor = UnionFindGetAncestor(node_map, a);
|
||||
int b_ancestor = UnionFindGetAncestor(node_map, b);
|
||||
node_map.at(b_ancestor)->attr(kUnionFindParent).Int32() = a_ancestor;
|
||||
node_map.at(a)->attr(kUnionFindParent).Int32() = a_ancestor;
|
||||
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
|
||||
std::vector<Node *> marked_nodes;
|
||||
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) {
|
||||
if (node.attr(kMarkerAttrName).Bool()) {
|
||||
marked_nodes.push_back(&node);
|
||||
}
|
||||
}
|
||||
// extract sub-graphs in the marked node set, use Union Find algorithm.
|
||||
node_map_t node_map; // id to ptr
|
||||
for (auto *n : marked_nodes) {
|
||||
// n's parent == n.id means it is the ancestor
|
||||
n->attr(kUnionFindParent).Int32() = n->id();
|
||||
node_map[n->id()] = n;
|
||||
}
|
||||
std::unordered_set<Node *> visited;
|
||||
for (auto *n : marked_nodes) {
|
||||
for (auto *out : n->outlinks) {
|
||||
if (node_map.count(out->id())) {
|
||||
UnionFindCombine(node_map, n->id(), out->id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<int /*ancestor*/, std::vector<Node *>> clusters;
|
||||
for (auto *n : marked_nodes) {
|
||||
if (n->type() == Node::Type::kFunction) {
|
||||
clusters[UnionFindGetAncestor(node_map,
|
||||
n->attr(kUnionFindParent).Int32())]
|
||||
.push_back(n);
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<Node *>> result;
|
||||
std::for_each(clusters.begin(), clusters.end(),
|
||||
[&](const decltype(clusters)::value_type &it) {
|
||||
result.push_back(it.second);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
|
||||
|
||||
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
|
||||
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
|
||||
for (auto &subgraph : subgraphs) {
|
||||
// replace this sub-graph with the first node. Two steps: 1. Create a Block
|
||||
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph
|
||||
// as deleted. 3. Replace the deleted node with the new Block Node.
|
||||
auto *block_node = graph_->nodes.Create(Node::Type::kFunctionBlock);
|
||||
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
|
||||
block_node->inlinks = std::move(io.first);
|
||||
block_node->outlinks = std::move(io.second);
|
||||
for (auto *node : subgraph) {
|
||||
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
|
||||
// pass.
|
||||
node->SetDeleted();
|
||||
}
|
||||
|
||||
std::unordered_map<Node *, Node *>
|
||||
delelte_node_map; // deleted node to BlockNode
|
||||
for (auto *n : block_node->inlinks) {
|
||||
n->inlinks.clear();
|
||||
}
|
||||
for (auto *n : block_node->outlinks) {
|
||||
n->outlinks.clear();
|
||||
}
|
||||
for (auto *n : block_node->inlinks) {
|
||||
n->outlinks.push_back(block_node);
|
||||
}
|
||||
for (auto *n : block_node->outlinks) {
|
||||
n->inlinks.push_back(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,81 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
/*
|
||||
* This file defines the the class to partition a graph.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
#include "paddle/fluid/inference/analysis/node.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
/*
|
||||
* Detect the nodes in a sub-graph that meet some conditions. This class doesn't
|
||||
* modify the graph.
|
||||
*/
|
||||
class SubGraphSplitter {
|
||||
public:
|
||||
static const char *kMarkerAttrName;
|
||||
// Tell whether a node is inside a sub-graph.
|
||||
using NodeInsideSubgraphTeller = std::function<bool(const Node *)>;
|
||||
|
||||
SubGraphSplitter(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
|
||||
: graph_(graph), node_inside_subgraph_teller_(teller) {}
|
||||
|
||||
std::vector<std::vector<Node *>> operator()();
|
||||
|
||||
protected:
|
||||
// Mark the nodes inside the accepted sub-graph using
|
||||
// node_inside_subgraph_teller.
|
||||
void MarkNodesInsideSubGraph();
|
||||
|
||||
// Merge the marked nodes into sub-graphs and return the sub-graphs.
|
||||
std::vector<std::vector<Node *>> ExtractSubGraphs();
|
||||
|
||||
private:
|
||||
DataFlowGraph *graph_;
|
||||
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
|
||||
};
|
||||
|
||||
/*
|
||||
* SubGraphFuse - Replace some nodes with the sub-graph node they are inside. To
|
||||
* some extent, the TensorRT engine is just a fusion op for a model.
|
||||
*/
|
||||
class SubGraphFuse {
|
||||
public:
|
||||
using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller;
|
||||
|
||||
SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
|
||||
: graph_(graph), node_inside_subgraph_teller_(teller) {}
|
||||
|
||||
// The main method which run all the logic.
|
||||
void operator()();
|
||||
|
||||
protected:
|
||||
// Remove the nodes inside sub-graphs and replace with the SubGraphNode.
|
||||
void ReplaceNodesWithSubGraphs();
|
||||
|
||||
private:
|
||||
DataFlowGraph *graph_;
|
||||
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,67 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
||||
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
TEST_F(DFG_Tester, Split) {
|
||||
auto desc = LoadProgramDesc();
|
||||
auto dfg = ProgramDescToDFG(desc);
|
||||
LOG(INFO) << "spliter\n" << dfg.DotString();
|
||||
|
||||
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
|
||||
if (node->type() != Node::Type::kFunction) return false;
|
||||
const auto* func = static_cast<const Function*>(node);
|
||||
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
|
||||
func->func_type() == "conv2d" || func->func_type() == "mul" ||
|
||||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
|
||||
LOG(INFO) << "sub-graph marked " << node->repr();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
ASSERT_GT(dfg.nodes.size(), 5UL);
|
||||
|
||||
auto subgraphs = SubGraphSplitter(&dfg, teller)();
|
||||
|
||||
// Check the number of the marked nodes.
|
||||
int marked_nodes = 0;
|
||||
for (auto& node : dfg.nodes.nodes()) {
|
||||
if (node->IsFunction() &&
|
||||
node->attr(SubGraphSplitter::kMarkerAttrName).Bool()) {
|
||||
++marked_nodes;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(marked_nodes, 6);
|
||||
|
||||
// For human debug.
|
||||
for (auto& subgraph : subgraphs) {
|
||||
LOG(INFO) << "subgraph size " << subgraph.size();
|
||||
for (auto* node : subgraph) {
|
||||
LOG(INFO) << "node " << node->repr();
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(subgraphs.size(), 1UL);
|
||||
// The last sub-graph has 5 Functions.
|
||||
ASSERT_EQ(subgraphs.back().size(), 6UL);
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,59 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <gflags/gflags.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/executor.h"
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||
#include "paddle/fluid/inference/io.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
DEFINE_string(inference_model_dir, "", "inference test model dir");
|
||||
|
||||
static framework::proto::ProgramDesc LoadProgramDesc(
|
||||
const std::string& model_dir = FLAGS_inference_model_dir) {
|
||||
// TODO(Superjomn) update latter.
|
||||
auto place = paddle::platform::CPUPlace();
|
||||
auto executor = paddle::framework::Executor(place);
|
||||
auto* scope = new paddle::framework::Scope();
|
||||
auto program = Load(&executor, scope, model_dir);
|
||||
return *program->Proto();
|
||||
}
|
||||
|
||||
static DataFlowGraph ProgramDescToDFG(
|
||||
const framework::proto::ProgramDesc& desc) {
|
||||
DataFlowGraph graph;
|
||||
FluidToDataFlowGraphPass pass;
|
||||
pass.Initialize(desc);
|
||||
pass.Run(&graph);
|
||||
pass.Finalize();
|
||||
return graph;
|
||||
}
|
||||
|
||||
class DFG_Tester : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override { desc = LoadProgramDesc(FLAGS_inference_model_dir); }
|
||||
|
||||
framework::proto::ProgramDesc desc;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue