Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/refine_parallel_executor
commit
fcbf19bf93
File diff suppressed because it is too large
Load Diff
@ -1,2 +1,17 @@
|
|||||||
cc_library(analysis SRCS dot.cc node.cc node.h)
|
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
|
||||||
|
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc
|
||||||
|
DEPS paddle_fluid)
|
||||||
cc_test(test_node SRCS node_tester.cc DEPS analysis)
|
cc_test(test_node SRCS node_tester.cc DEPS analysis)
|
||||||
|
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
|
||||||
|
|
||||||
|
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
|
||||||
|
|
||||||
|
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid
|
||||||
|
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
|
||||||
|
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec)
|
||||||
|
|
||||||
|
cc_test(test_subgraph_splitter
|
||||||
|
SRCS subgraph_splitter_tester.cc
|
||||||
|
DEPS analysis paddle_fluid tensor
|
||||||
|
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
|
||||||
|
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec)
|
||||||
|
@ -0,0 +1,205 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/dot.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
// It is a better idea that the inputs and outputs of this graph is set manully
|
||||||
|
// before, but there must be a Pass that helps to prune the unnecessary ops that
|
||||||
|
// do not contribute to the given targets, so in this pass, analysis and get the
|
||||||
|
// inputs and outputs is OK.
|
||||||
|
void DataFlowGraph::Build() {
|
||||||
|
inputs.clear();
|
||||||
|
outputs.clear();
|
||||||
|
std::unordered_set<Node *> ins;
|
||||||
|
std::unordered_set<Node *> outs;
|
||||||
|
for (auto &node : nodes.nodes()) {
|
||||||
|
for (auto *in : node->inlinks) {
|
||||||
|
ins.insert(in);
|
||||||
|
}
|
||||||
|
for (auto *out : node->outlinks) {
|
||||||
|
outs.insert(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The nodes that in ins but not in outs is the graph's inputs
|
||||||
|
// similarly, the nodes that in outs but not in ins is the graphs' outputs
|
||||||
|
for (auto *in : ins) {
|
||||||
|
if (!outs.count(in)) {
|
||||||
|
inputs.push_back(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto *out : outs) {
|
||||||
|
if (!outs.count(out)) {
|
||||||
|
outputs.push_back(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DataFlowGraph::DotString() const {
|
||||||
|
Dot dot;
|
||||||
|
|
||||||
|
// Add nodes
|
||||||
|
for (size_t i = 0; i < nodes.size(); i++) {
|
||||||
|
const Node &node = nodes.Get(i);
|
||||||
|
switch (node.type()) {
|
||||||
|
case Node::Type::kValue:
|
||||||
|
dot.AddNode(node.repr(), node.dot_attrs());
|
||||||
|
break;
|
||||||
|
case Node::Type::kFunction:
|
||||||
|
dot.AddNode(node.repr(), node.dot_attrs());
|
||||||
|
break;
|
||||||
|
case Node::Type::kFunctionBlock:
|
||||||
|
dot.AddNode(node.repr(), node.dot_attrs());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add edges
|
||||||
|
for (size_t i = 0; i < nodes.size(); i++) {
|
||||||
|
const Node &node = nodes.Get(i);
|
||||||
|
for (auto &in : node.inlinks) {
|
||||||
|
dot.AddEdge(in->repr(), node.repr(), {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dot.Build();
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// NodesBFSIterator
|
||||||
|
//
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
|
||||||
|
const std::vector<Node *> &source)
|
||||||
|
: queue_(source.begin(), source.end()) {}
|
||||||
|
|
||||||
|
// GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
|
||||||
|
// GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept
|
||||||
|
// : queue_(std::move(other.queue_)),
|
||||||
|
// visited_(std::move(other.visited_)) {}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
|
||||||
|
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other)
|
||||||
|
: queue_(other.queue_), visited_(other.visited_) {}
|
||||||
|
|
||||||
|
Node &GraphTraits<DataFlowGraph>::NodesBFSIterator::operator*() {
|
||||||
|
PADDLE_ENFORCE(!queue_.empty());
|
||||||
|
return *queue_.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
Node *GraphTraits<DataFlowGraph>::NodesBFSIterator::operator->() {
|
||||||
|
PADDLE_ENFORCE(!queue_.empty());
|
||||||
|
return queue_.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesBFSIterator &
|
||||||
|
GraphTraits<DataFlowGraph>::NodesBFSIterator::operator=(
|
||||||
|
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
|
||||||
|
queue_ = other.queue_;
|
||||||
|
visited_ = other.visited_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesBFSIterator
|
||||||
|
&GraphTraits<DataFlowGraph>::NodesBFSIterator::operator++() {
|
||||||
|
PADDLE_ENFORCE(!queue_.empty());
|
||||||
|
auto *cur = queue_.front();
|
||||||
|
visited_.insert(cur);
|
||||||
|
queue_.pop_front();
|
||||||
|
for (auto *output : cur->outlinks) {
|
||||||
|
if (!visited_.count(output)) {
|
||||||
|
queue_.push_back(output);
|
||||||
|
visited_.insert(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
|
||||||
|
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
|
||||||
|
if (queue_.empty()) return other.queue_.empty();
|
||||||
|
if ((!queue_.empty()) && (!other.queue_.empty())) {
|
||||||
|
return queue_.front() == other.queue_.front() &&
|
||||||
|
visited_.size() == other.visited_.size(); // here need to check the
|
||||||
|
// equality of queue and
|
||||||
|
// visited. Just a light but week implementation.
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// NodesDFSIterator
|
||||||
|
//
|
||||||
|
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
|
||||||
|
const std::vector<Node *> &source) {
|
||||||
|
for (auto *x : source) stack_.push(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
|
||||||
|
// GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept
|
||||||
|
// : stack_(std::move(other.stack_)),
|
||||||
|
// visited_(std::move(other.visited_)) {}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
|
||||||
|
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other)
|
||||||
|
: stack_(other.stack_), visited_(other.visited_) {}
|
||||||
|
|
||||||
|
Node &GraphTraits<DataFlowGraph>::NodesDFSIterator::operator*() {
|
||||||
|
PADDLE_ENFORCE(!stack_.empty());
|
||||||
|
return *stack_.top();
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesDFSIterator
|
||||||
|
&GraphTraits<DataFlowGraph>::NodesDFSIterator::operator++() {
|
||||||
|
if (stack_.empty()) return *this;
|
||||||
|
visited_.insert(stack_.top());
|
||||||
|
auto *cur = stack_.top();
|
||||||
|
stack_.pop();
|
||||||
|
for (auto *x : cur->outlinks) {
|
||||||
|
if (!visited_.count(x)) {
|
||||||
|
stack_.push(x);
|
||||||
|
visited_.insert(x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
bool GraphTraits<DataFlowGraph>::NodesDFSIterator::operator==(
|
||||||
|
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
|
||||||
|
if (stack_.empty()) return other.stack_.empty();
|
||||||
|
if ((!stack_.empty()) && (!other.stack_.empty())) {
|
||||||
|
return stack_.top() == other.stack_.top();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph>::NodesDFSIterator &
|
||||||
|
GraphTraits<DataFlowGraph>::NodesDFSIterator::operator=(
|
||||||
|
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
|
||||||
|
stack_ = other.stack_;
|
||||||
|
visited_ = other.visited_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
|
||||||
|
return stack_.top();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,159 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Data flow graph is an pass that build the basic graph. It contains a graph
|
||||||
|
* and the iterators that enable the iteration over the graph.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <stack>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/graph_traits.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/node.h"
|
||||||
|
#include "paddle/fluid/platform/enforce.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* DataFlowGraph - A container of Value and Function Nodes.
|
||||||
|
*/
|
||||||
|
struct DataFlowGraph {
|
||||||
|
NodeMap nodes;
|
||||||
|
std::vector<Node *> inputs;
|
||||||
|
std::vector<Node *> outputs;
|
||||||
|
|
||||||
|
// Extract inputs and outputs of the graph.
|
||||||
|
void Build();
|
||||||
|
|
||||||
|
// Output a DOT graph file for debug.
|
||||||
|
std::string DotString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* An graph trait help to traverse the graph using BFS.
|
||||||
|
* The BFS start from a graph's inputs, the graph should be fully-connected, so
|
||||||
|
* that the iterator can reach the end.
|
||||||
|
*/
|
||||||
|
template <>
|
||||||
|
struct GraphTraits<DataFlowGraph> {
|
||||||
|
// BFS iterator on nodes.
|
||||||
|
struct NodesBFSIterator
|
||||||
|
: public std::iterator<std::forward_iterator_tag, Node *> {
|
||||||
|
NodesBFSIterator() = default;
|
||||||
|
explicit NodesBFSIterator(const std::vector<Node *> &source);
|
||||||
|
// NodesBFSIterator(NodesBFSIterator &&other) noexcept;
|
||||||
|
// NOTE Heavy to use.
|
||||||
|
NodesBFSIterator(const NodesBFSIterator &other);
|
||||||
|
|
||||||
|
Node &operator*();
|
||||||
|
NodesBFSIterator &operator++();
|
||||||
|
Node *operator->();
|
||||||
|
// TODO(Superjomn) current implementation just compare the first
|
||||||
|
// element, need to compare the graph and all the elements in the queue and
|
||||||
|
// set.
|
||||||
|
NodesBFSIterator &operator=(const NodesBFSIterator &other);
|
||||||
|
bool operator==(const NodesBFSIterator &other);
|
||||||
|
bool operator!=(const NodesBFSIterator &other) { return !(*this == other); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::deque<Node *> queue_;
|
||||||
|
std::unordered_set<Node *> visited_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// DFS iterator on nodes.
|
||||||
|
struct NodesDFSIterator
|
||||||
|
: public std::iterator<std::forward_iterator_tag, Node *> {
|
||||||
|
NodesDFSIterator() = default;
|
||||||
|
explicit NodesDFSIterator(const std::vector<Node *> &source);
|
||||||
|
// NodesDFSIterator(NodesDFSIterator &&other) noexcept;
|
||||||
|
NodesDFSIterator(const NodesDFSIterator &other);
|
||||||
|
|
||||||
|
Node &operator*();
|
||||||
|
NodesDFSIterator &operator++();
|
||||||
|
// TODO(Superjomn) current implementation just compare the first
|
||||||
|
// element, need to compare the graph and all the elements in the queue and
|
||||||
|
// set.
|
||||||
|
NodesDFSIterator &operator=(const NodesDFSIterator &other);
|
||||||
|
bool operator==(const NodesDFSIterator &other);
|
||||||
|
bool operator!=(const NodesDFSIterator &other) { return !(*this == other); }
|
||||||
|
Node *operator->();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::stack<Node *> stack_;
|
||||||
|
std::unordered_set<Node *> visited_;
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
|
||||||
|
|
||||||
|
// default use BFS to visit the nodes.
|
||||||
|
iterator_range<NodesBFSIterator> nodes() {
|
||||||
|
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
|
||||||
|
}
|
||||||
|
iterator_range<NodesBFSIterator> nodes_in_BFS() {
|
||||||
|
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
|
||||||
|
}
|
||||||
|
iterator_range<NodesDFSIterator> nodes_in_DFS() {
|
||||||
|
return iterator_range<NodesDFSIterator>(nodes_dfs_begin(), nodes_dfs_end());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
NodesBFSIterator nodes_bfs_begin() {
|
||||||
|
return NodesBFSIterator(graph_->inputs);
|
||||||
|
}
|
||||||
|
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
|
||||||
|
NodesDFSIterator nodes_dfs_begin() {
|
||||||
|
return NodesDFSIterator(graph_->inputs);
|
||||||
|
}
|
||||||
|
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataFlowGraph *graph_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract the inputs and outputs of a graph. The inputs and outputs of a
|
||||||
|
// sub-graph is the inputs nodes and output nodes that doesn't inside the
|
||||||
|
// sub-graph.
|
||||||
|
std::pair<
|
||||||
|
std::vector<Node *>,
|
||||||
|
std::vector<
|
||||||
|
Node *>> static ExtractInputAndOutputOfSubGraph(std::vector<Node *>
|
||||||
|
&graph) {
|
||||||
|
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
|
||||||
|
std::unordered_set<Node *> inputs;
|
||||||
|
std::unordered_set<Node *> outputs;
|
||||||
|
for (auto &node : graph) {
|
||||||
|
for (auto *in : node->inlinks) {
|
||||||
|
if (!nodes.count(in) && in->type() == Node::Type::kValue) {
|
||||||
|
inputs.insert(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto *out : node->outlinks) {
|
||||||
|
if (!nodes.count(out) && out->type() == Node::Type::kValue) {
|
||||||
|
outputs.insert(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
|
||||||
|
std::vector<Node *>(outputs.begin(), outputs.end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,62 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
TEST(DataFlowGraph, BFS) {
|
||||||
|
auto desc = LoadProgramDesc();
|
||||||
|
auto dfg = ProgramDescToDFG(desc);
|
||||||
|
dfg.Build();
|
||||||
|
|
||||||
|
for (auto* in : dfg.inputs) {
|
||||||
|
LOG(INFO) << "inputs: " << in->name() << " "
|
||||||
|
<< static_cast<int>(in->type());
|
||||||
|
}
|
||||||
|
for (auto* out : dfg.outputs) {
|
||||||
|
LOG(INFO) << "outputs: " << out->name() << " "
|
||||||
|
<< static_cast<int>(out->type());
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphTraits<DataFlowGraph> trait(&dfg);
|
||||||
|
auto nodes = trait.nodes();
|
||||||
|
int count = 0;
|
||||||
|
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
|
||||||
|
LOG(INFO) << "visiting " << it->name();
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(count, dfg.nodes.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DataFlowGraph, DFS) {
|
||||||
|
auto desc = LoadProgramDesc();
|
||||||
|
auto dfg = ProgramDescToDFG(desc);
|
||||||
|
dfg.Build();
|
||||||
|
GraphTraits<DataFlowGraph> trait(&dfg);
|
||||||
|
auto nodes = trait.nodes_in_DFS();
|
||||||
|
int count = 0;
|
||||||
|
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
|
||||||
|
LOG(INFO) << "visiting " << it->name();
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(count, dfg.nodes.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,49 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
|
||||||
|
|
||||||
|
#include <glog/logging.h>
|
||||||
|
#include <google/protobuf/text_format.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "paddle/fluid/framework/executor.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||||
|
#include "paddle/fluid/inference/io.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
TEST_F(DFG_Tester, Test) {
|
||||||
|
framework::proto::ProgramDesc new_desc;
|
||||||
|
DataFlowGraph graph;
|
||||||
|
|
||||||
|
FluidToDataFlowGraphPass pass0;
|
||||||
|
DataFlowGraphToFluidPass pass1;
|
||||||
|
pass0.Initialize(desc);
|
||||||
|
pass1.Initialize(&new_desc);
|
||||||
|
|
||||||
|
pass0.Run(&graph);
|
||||||
|
pass1.Run(&graph);
|
||||||
|
|
||||||
|
pass0.Finalize();
|
||||||
|
pass1.Finalize();
|
||||||
|
|
||||||
|
LOG(INFO) << graph.nodes.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // analysis
|
||||||
|
} // inference
|
||||||
|
} // paddle
|
@ -0,0 +1,83 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
FluidToDataFlowGraphPass::FluidToDataFlowGraphPass() {}
|
||||||
|
|
||||||
|
bool FluidToDataFlowGraphPass::Initialize() { return Pass::Initialize(); }
|
||||||
|
|
||||||
|
bool FluidToDataFlowGraphPass::Initialize(
|
||||||
|
const framework::proto::ProgramDesc &desc) {
|
||||||
|
desc_ = &desc;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); }
|
||||||
|
|
||||||
|
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
|
||||||
|
// insert vars
|
||||||
|
std::unordered_map<std::string, size_t> var2id;
|
||||||
|
auto &main_block = desc_->blocks(framework::kRootBlockIndex);
|
||||||
|
for (int i = 0; i < main_block.vars_size(); i++) {
|
||||||
|
const auto &var = main_block.vars(i);
|
||||||
|
auto *v = graph->nodes.Create(Node::Type::kValue);
|
||||||
|
v->SetName(var.name());
|
||||||
|
v->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&var)));
|
||||||
|
var2id[var.name()] = v->id();
|
||||||
|
}
|
||||||
|
for (int i = 0; i < main_block.ops_size(); i++) {
|
||||||
|
const auto &op = main_block.ops(i);
|
||||||
|
auto *o = graph->nodes.Create(Node::Type::kFunction);
|
||||||
|
o->SetName(op.type());
|
||||||
|
static_cast<Function *>(o)->SetFuncType(op.type());
|
||||||
|
// Link to the original protobuf message's memory, make it easier to
|
||||||
|
// generate from a data flow graph to fluid ProgramDesc.
|
||||||
|
o->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&op)));
|
||||||
|
// set inputs and outputs
|
||||||
|
// TODO(Superjomn) make sure the InputNames is the real variable name.
|
||||||
|
for (int j = 0; j < op.inputs_size(); j++) {
|
||||||
|
auto &in_var = op.inputs(j);
|
||||||
|
for (int k = 0; k < in_var.arguments_size(); k++) {
|
||||||
|
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
|
||||||
|
in->outlinks.push_back(o);
|
||||||
|
o->inlinks.push_back(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j = 0; j < op.outputs_size(); j++) {
|
||||||
|
auto &out_var = op.outputs(j);
|
||||||
|
for (int k = 0; k < out_var.arguments_size(); k++) {
|
||||||
|
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
|
||||||
|
out->inlinks.push_back(o);
|
||||||
|
o->outlinks.push_back(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Analysis and extract the inputs and outputs of this graph.
|
||||||
|
graph->Build();
|
||||||
|
}
|
||||||
|
|
||||||
|
Pass *FluidToDataFlowGraphPass::CreatePrinterPass(
|
||||||
|
std::ostream &os, const std::string &banner) const {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,51 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This file implements the transformation from data flow graph to fluid
|
||||||
|
* ProgramDesc.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/program_desc.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/pass.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Transform a FluidDesc to a data flow graph.
|
||||||
|
*/
|
||||||
|
class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
|
||||||
|
public:
|
||||||
|
FluidToDataFlowGraphPass();
|
||||||
|
bool Initialize() override;
|
||||||
|
bool Initialize(const framework::proto::ProgramDesc &desc) override;
|
||||||
|
bool Finalize() override;
|
||||||
|
|
||||||
|
void Run(DataFlowGraph *graph) override;
|
||||||
|
|
||||||
|
Pass *CreatePrinterPass(std::ostream &os,
|
||||||
|
const std::string &banner) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
framework::proto::ProgramDesc const *desc_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
TEST_F(DFG_Tester, Init) {
|
||||||
|
FluidToDataFlowGraphPass pass;
|
||||||
|
pass.Initialize();
|
||||||
|
pass.Initialize(desc);
|
||||||
|
DataFlowGraph graph;
|
||||||
|
pass.Run(&graph);
|
||||||
|
ASSERT_GT(graph.nodes.size(), 0);
|
||||||
|
pass.Finalize();
|
||||||
|
LOG(INFO) << '\n' << graph.DotString();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // analysis
|
||||||
|
} // inference
|
||||||
|
} // paddle
|
@ -0,0 +1,15 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/graph_traits.h"
|
@ -0,0 +1,63 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This file defines the GraphTraits<X> template class that should be specified
|
||||||
|
* by classes that want to be iteratable by generic graph iterators.
|
||||||
|
*
|
||||||
|
* This file also defines the marker class Inverse that is used to iterate over
|
||||||
|
* graphs in a graph defined, inverse ordering...
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This class should be specialized by different graph types...
|
||||||
|
* That's why the base class is empty.
|
||||||
|
*/
|
||||||
|
template <typename GraphType>
|
||||||
|
struct GraphTraits {
|
||||||
|
// using NodesBFSIterator = xxx
|
||||||
|
|
||||||
|
// NodesBFSIterator nodes_begin();
|
||||||
|
// NodesBFSIterator nodes_end();
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Inverse - This class is used as a marker class to tell the graph iterator to
|
||||||
|
* iterate in a graph defined Inverse order.
|
||||||
|
*/
|
||||||
|
template <typename GraphType>
|
||||||
|
struct Inverse {
|
||||||
|
const GraphType &graph;
|
||||||
|
|
||||||
|
explicit Inverse(const GraphType &graph) : graph(graph) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Provide a partial specialization of GraphTraits so that the inverse of an
|
||||||
|
* inverse turns into the original graph.
|
||||||
|
*/
|
||||||
|
template <typename GraphType>
|
||||||
|
struct GraphTraits<Inverse<Inverse<GraphType>>> : GraphTraits<GraphType> {};
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -1,74 +1,107 @@
|
|||||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License. */
|
limitations under the License. */
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "paddle/fluid/platform/enforce.h"
|
#include "paddle/fluid/platform/enforce.h"
|
||||||
|
|
||||||
namespace paddle {
|
namespace paddle {
|
||||||
namespace inference {
|
namespace inference {
|
||||||
namespace analysis {
|
namespace analysis {
|
||||||
|
|
||||||
template <typename IteratorT>
|
#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
|
||||||
class iterator_range {
|
/*
|
||||||
IteratorT begin_, end_;
|
* Map typeid to representation.
|
||||||
|
*/
|
||||||
public:
|
struct DataTypeNamer {
|
||||||
template <typename Container>
|
static const DataTypeNamer &Global() {
|
||||||
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
|
static auto *x = new DataTypeNamer();
|
||||||
|
return *x;
|
||||||
iterator_range(const IteratorT &begin, const IteratorT &end)
|
}
|
||||||
: begin_(begin), end_(end) {}
|
|
||||||
|
template <typename T>
|
||||||
const IteratorT &begin() const { return begin_; }
|
const std::string &repr() const {
|
||||||
const IteratorT &end() const { return end_; }
|
auto x = typeid(T).hash_code();
|
||||||
};
|
PADDLE_ENFORCE(dic_.count(x), "unknown type for representation");
|
||||||
|
return dic_.at(x);
|
||||||
/*
|
}
|
||||||
* An registry helper class, with its records keeps the order they registers.
|
|
||||||
*/
|
const std::string &repr(size_t &hash) const {
|
||||||
template <typename T>
|
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
|
||||||
class OrderedRegistry {
|
return dic_.at(hash);
|
||||||
public:
|
}
|
||||||
T *Register(const std::string &name, T *x) {
|
|
||||||
PADDLE_ENFORCE(!dic_.count(name));
|
private:
|
||||||
dic_[name] = data_.size();
|
DataTypeNamer() {
|
||||||
data_.emplace_back(std::unique_ptr<T>(x));
|
SET_TYPE(int);
|
||||||
return data_.back().get();
|
SET_TYPE(bool);
|
||||||
}
|
SET_TYPE(float);
|
||||||
|
}
|
||||||
T *Lookup(const std::string &name) {
|
|
||||||
auto it = dic_.find(name);
|
std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_;
|
||||||
if (it == dic_.end()) return nullptr;
|
};
|
||||||
return data_[it->second].get();
|
#undef SET_TYPE
|
||||||
}
|
|
||||||
|
template <typename IteratorT>
|
||||||
protected:
|
class iterator_range {
|
||||||
std::unordered_map<std::string, int> dic_;
|
IteratorT begin_, end_;
|
||||||
std::vector<std::unique_ptr<T>> data_;
|
|
||||||
};
|
public:
|
||||||
|
template <typename Container>
|
||||||
} // namespace analysis
|
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
|
||||||
} // namespace inference
|
|
||||||
} // namespace paddle
|
iterator_range(const IteratorT &begin, const IteratorT &end)
|
||||||
|
: begin_(begin), end_(end) {}
|
||||||
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
|
|
||||||
\
|
const IteratorT &begin() const { return begin_; }
|
||||||
type__(const type__ &) = delete; \
|
const IteratorT &end() const { return end_; }
|
||||||
\
|
};
|
||||||
void operator=(const type__ &) = delete;
|
|
||||||
|
/*
|
||||||
|
* An registry helper class, with its records keeps the order they registers.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
class OrderedRegistry {
|
||||||
|
public:
|
||||||
|
T *Register(const std::string &name, T *x) {
|
||||||
|
PADDLE_ENFORCE(!dic_.count(name));
|
||||||
|
dic_[name] = data_.size();
|
||||||
|
data_.emplace_back(std::unique_ptr<T>(x));
|
||||||
|
return data_.back().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
T *Lookup(const std::string &name) {
|
||||||
|
auto it = dic_.find(name);
|
||||||
|
if (it == dic_.end()) return nullptr;
|
||||||
|
return data_[it->second].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unordered_map<std::string, int> dic_;
|
||||||
|
std::vector<std::unique_ptr<T>> data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
|
||||||
|
\
|
||||||
|
type__(const type__ &) = delete; \
|
||||||
|
\
|
||||||
|
void operator=(const type__ &) = delete;
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/pass.h"
|
@ -0,0 +1,90 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <glog/logging.h>
|
||||||
|
#include <iosfwd>
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/framework.pb.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/helper.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/node.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
class Pass {
|
||||||
|
public:
|
||||||
|
Pass() = default;
|
||||||
|
virtual ~Pass() {}
|
||||||
|
// Virtual method overridden by subclasses to do only necessary initialization
|
||||||
|
// before any pass is run.
|
||||||
|
virtual bool Initialize() { return false; }
|
||||||
|
// There is some passes such as FlowToDataFlowGraphPass that needs a
|
||||||
|
// ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it
|
||||||
|
// only couple with the proto file.
|
||||||
|
virtual bool Initialize(const framework::proto::ProgramDesc &desc) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// There are some Passes such as DataFlowGraphToFluidPass that will output a
|
||||||
|
// ProgramDesc.
|
||||||
|
virtual bool Initialize(framework::proto::ProgramDesc *desc) { return false; }
|
||||||
|
|
||||||
|
// Virtual method overriden by subclasses to do any necessary clean up after
|
||||||
|
// all passes have run.
|
||||||
|
virtual bool Finalize() { return false; }
|
||||||
|
|
||||||
|
// Get a Pass appropriate to print the Node this pass operates on.
|
||||||
|
virtual Pass *CreatePrinterPass(std::ostream &os,
|
||||||
|
const std::string &banner) const = 0;
|
||||||
|
|
||||||
|
// Run on a single Node.
|
||||||
|
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
|
||||||
|
// Run on a single Function.
|
||||||
|
virtual void Run(Function *x) { LOG(FATAL) << "not valid"; }
|
||||||
|
// Run on a single FunctionBlock.
|
||||||
|
virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; }
|
||||||
|
// Run on a single DataFlowGraph.
|
||||||
|
virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// NodePass process on any Node types.
|
||||||
|
class NodePass : public Pass {
|
||||||
|
public:
|
||||||
|
virtual void Run(Node *node) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// NodePass process on any Function node types.
|
||||||
|
class FunctionPass : public Pass {
|
||||||
|
public:
|
||||||
|
virtual void Run(Function *node) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// NodePass process on any FunctionBlock node types.
|
||||||
|
class FunctionBlockPass : public Pass {
|
||||||
|
public:
|
||||||
|
virtual void Run(FunctionBlock *node) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// GraphPass processes on any GraphType.
|
||||||
|
class DataFlowGraphPass : public Pass {
|
||||||
|
public:
|
||||||
|
virtual void Run(DataFlowGraph *graph) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,154 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
const char *SubGraphSplitter::kMarkerAttrName =
|
||||||
|
"_sub_graph_splitter_inside_sub_graph";
|
||||||
|
|
||||||
|
std::vector<std::vector<Node *>> SubGraphSplitter::operator()() {
|
||||||
|
MarkNodesInsideSubGraph();
|
||||||
|
return ExtractSubGraphs();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the output variables inside a subgraph with the func.
|
||||||
|
inline void MarkOutLinksInSubGraph(const Function *func) {
|
||||||
|
for (auto *var : func->outlinks) {
|
||||||
|
var->attr(SubGraphSplitter::kMarkerAttrName).Bool() = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SubGraphSplitter::MarkNodesInsideSubGraph() {
|
||||||
|
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) {
|
||||||
|
if (node_inside_subgraph_teller_(&node)) {
|
||||||
|
node.attr(kMarkerAttrName).Bool() = true;
|
||||||
|
if (node.type() == Node::Type::kFunction) {
|
||||||
|
// If a function is inside the sub-graph, mark all the output variables
|
||||||
|
// to be inside too, so that two marked functions will be inside a same
|
||||||
|
// sub-graph, lets take a example: A_function->var->B_function, if
|
||||||
|
// A_function is marked, var should also be marked, so that B_function
|
||||||
|
// will be in the same sub-graph with A_function if B_function is
|
||||||
|
// marked.
|
||||||
|
MarkOutLinksInSubGraph(static_cast<const Function *>(&node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *kUnionFindParent = "_sub_graph_splitter_union_find_parent_";
|
||||||
|
|
||||||
|
// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node
|
||||||
|
// a's output is node b, that is a and b is in the same sub-graph. The UF
|
||||||
|
// algorithm will group them to the same cluster.
|
||||||
|
using node_map_t = std::unordered_map<int, Node *>;
|
||||||
|
// Find the ancestor id of a node.
|
||||||
|
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
|
||||||
|
int tmp = id;
|
||||||
|
do {
|
||||||
|
tmp = node_map.at(tmp)->attr(kUnionFindParent).Int32();
|
||||||
|
} while (node_map.at(tmp)->attr(kUnionFindParent).Int32() != tmp);
|
||||||
|
return tmp;
|
||||||
|
}
|
||||||
|
// Make this two node share the same ancestor.
|
||||||
|
// TODO(Superjom) bad performance, make a balanced tree latter.
|
||||||
|
void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
|
||||||
|
int a_ancestor = UnionFindGetAncestor(node_map, a);
|
||||||
|
int b_ancestor = UnionFindGetAncestor(node_map, b);
|
||||||
|
node_map.at(b_ancestor)->attr(kUnionFindParent).Int32() = a_ancestor;
|
||||||
|
node_map.at(a)->attr(kUnionFindParent).Int32() = a_ancestor;
|
||||||
|
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
|
||||||
|
std::vector<Node *> marked_nodes;
|
||||||
|
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) {
|
||||||
|
if (node.attr(kMarkerAttrName).Bool()) {
|
||||||
|
marked_nodes.push_back(&node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// extract sub-graphs in the marked node set, use Union Find algorithm.
|
||||||
|
node_map_t node_map; // id to ptr
|
||||||
|
for (auto *n : marked_nodes) {
|
||||||
|
// n's parent == n.id means it is the ancestor
|
||||||
|
n->attr(kUnionFindParent).Int32() = n->id();
|
||||||
|
node_map[n->id()] = n;
|
||||||
|
}
|
||||||
|
std::unordered_set<Node *> visited;
|
||||||
|
for (auto *n : marked_nodes) {
|
||||||
|
for (auto *out : n->outlinks) {
|
||||||
|
if (node_map.count(out->id())) {
|
||||||
|
UnionFindCombine(node_map, n->id(), out->id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_map<int /*ancestor*/, std::vector<Node *>> clusters;
|
||||||
|
for (auto *n : marked_nodes) {
|
||||||
|
if (n->type() == Node::Type::kFunction) {
|
||||||
|
clusters[UnionFindGetAncestor(node_map,
|
||||||
|
n->attr(kUnionFindParent).Int32())]
|
||||||
|
.push_back(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<std::vector<Node *>> result;
|
||||||
|
std::for_each(clusters.begin(), clusters.end(),
|
||||||
|
[&](const decltype(clusters)::value_type &it) {
|
||||||
|
result.push_back(it.second);
|
||||||
|
});
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
|
||||||
|
|
||||||
|
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
|
||||||
|
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
|
||||||
|
for (auto &subgraph : subgraphs) {
|
||||||
|
// replace this sub-graph with the first node. Two steps: 1. Create a Block
|
||||||
|
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph
|
||||||
|
// as deleted. 3. Replace the deleted node with the new Block Node.
|
||||||
|
auto *block_node = graph_->nodes.Create(Node::Type::kFunctionBlock);
|
||||||
|
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
|
||||||
|
block_node->inlinks = std::move(io.first);
|
||||||
|
block_node->outlinks = std::move(io.second);
|
||||||
|
for (auto *node : subgraph) {
|
||||||
|
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
|
||||||
|
// pass.
|
||||||
|
node->SetDeleted();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_map<Node *, Node *>
|
||||||
|
delelte_node_map; // deleted node to BlockNode
|
||||||
|
for (auto *n : block_node->inlinks) {
|
||||||
|
n->inlinks.clear();
|
||||||
|
}
|
||||||
|
for (auto *n : block_node->outlinks) {
|
||||||
|
n->outlinks.clear();
|
||||||
|
}
|
||||||
|
for (auto *n : block_node->inlinks) {
|
||||||
|
n->outlinks.push_back(block_node);
|
||||||
|
}
|
||||||
|
for (auto *n : block_node->outlinks) {
|
||||||
|
n->inlinks.push_back(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,81 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This file defines the the class to partition a graph.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/node.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Detect the nodes in a sub-graph that meet some conditions. This class doesn't
|
||||||
|
* modify the graph.
|
||||||
|
*/
|
||||||
|
class SubGraphSplitter {
|
||||||
|
public:
|
||||||
|
static const char *kMarkerAttrName;
|
||||||
|
// Tell whether a node is inside a sub-graph.
|
||||||
|
using NodeInsideSubgraphTeller = std::function<bool(const Node *)>;
|
||||||
|
|
||||||
|
SubGraphSplitter(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
|
||||||
|
: graph_(graph), node_inside_subgraph_teller_(teller) {}
|
||||||
|
|
||||||
|
std::vector<std::vector<Node *>> operator()();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Mark the nodes inside the accepted sub-graph using
|
||||||
|
// node_inside_subgraph_teller.
|
||||||
|
void MarkNodesInsideSubGraph();
|
||||||
|
|
||||||
|
// Merge the marked nodes into sub-graphs and return the sub-graphs.
|
||||||
|
std::vector<std::vector<Node *>> ExtractSubGraphs();
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataFlowGraph *graph_;
|
||||||
|
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* SubGraphFuse - Replace some nodes with the sub-graph node they are inside. To
|
||||||
|
* some extent, the TensorRT engine is just a fusion op for a model.
|
||||||
|
*/
|
||||||
|
class SubGraphFuse {
|
||||||
|
public:
|
||||||
|
using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller;
|
||||||
|
|
||||||
|
SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller)
|
||||||
|
: graph_(graph), node_inside_subgraph_teller_(teller) {}
|
||||||
|
|
||||||
|
// The main method which run all the logic.
|
||||||
|
void operator()();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Remove the nodes inside sub-graphs and replace with the SubGraphNode.
|
||||||
|
void ReplaceNodesWithSubGraphs();
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataFlowGraph *graph_;
|
||||||
|
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,67 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
TEST_F(DFG_Tester, Split) {
|
||||||
|
auto desc = LoadProgramDesc();
|
||||||
|
auto dfg = ProgramDescToDFG(desc);
|
||||||
|
LOG(INFO) << "spliter\n" << dfg.DotString();
|
||||||
|
|
||||||
|
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
|
||||||
|
if (node->type() != Node::Type::kFunction) return false;
|
||||||
|
const auto* func = static_cast<const Function*>(node);
|
||||||
|
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
|
||||||
|
func->func_type() == "conv2d" || func->func_type() == "mul" ||
|
||||||
|
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
|
||||||
|
LOG(INFO) << "sub-graph marked " << node->repr();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
ASSERT_GT(dfg.nodes.size(), 5UL);
|
||||||
|
|
||||||
|
auto subgraphs = SubGraphSplitter(&dfg, teller)();
|
||||||
|
|
||||||
|
// Check the number of the marked nodes.
|
||||||
|
int marked_nodes = 0;
|
||||||
|
for (auto& node : dfg.nodes.nodes()) {
|
||||||
|
if (node->IsFunction() &&
|
||||||
|
node->attr(SubGraphSplitter::kMarkerAttrName).Bool()) {
|
||||||
|
++marked_nodes;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(marked_nodes, 6);
|
||||||
|
|
||||||
|
// For human debug.
|
||||||
|
for (auto& subgraph : subgraphs) {
|
||||||
|
LOG(INFO) << "subgraph size " << subgraph.size();
|
||||||
|
for (auto* node : subgraph) {
|
||||||
|
LOG(INFO) << "node " << node->repr();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(subgraphs.size(), 1UL);
|
||||||
|
// The last sub-graph has 5 Functions.
|
||||||
|
ASSERT_EQ(subgraphs.back().size(), 6UL);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,59 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "paddle/fluid/framework/executor.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
|
||||||
|
#include "paddle/fluid/inference/analysis/ut_helper.h"
|
||||||
|
#include "paddle/fluid/inference/io.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace analysis {
|
||||||
|
|
||||||
|
DEFINE_string(inference_model_dir, "", "inference test model dir");
|
||||||
|
|
||||||
|
static framework::proto::ProgramDesc LoadProgramDesc(
|
||||||
|
const std::string& model_dir = FLAGS_inference_model_dir) {
|
||||||
|
// TODO(Superjomn) update latter.
|
||||||
|
auto place = paddle::platform::CPUPlace();
|
||||||
|
auto executor = paddle::framework::Executor(place);
|
||||||
|
auto* scope = new paddle::framework::Scope();
|
||||||
|
auto program = Load(&executor, scope, model_dir);
|
||||||
|
return *program->Proto();
|
||||||
|
}
|
||||||
|
|
||||||
|
static DataFlowGraph ProgramDescToDFG(
|
||||||
|
const framework::proto::ProgramDesc& desc) {
|
||||||
|
DataFlowGraph graph;
|
||||||
|
FluidToDataFlowGraphPass pass;
|
||||||
|
pass.Initialize(desc);
|
||||||
|
pass.Run(&graph);
|
||||||
|
pass.Finalize();
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
class DFG_Tester : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override { desc = LoadProgramDesc(FLAGS_inference_model_dir); }
|
||||||
|
|
||||||
|
framework::proto::ProgramDesc desc;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace analysis
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue