createGenDocLib
robot 7 years ago
commit 31c90692f7

@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)

File diff suppressed because it is too large Load Diff

@ -1,4 +1,4 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,12 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/dot.h"
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace inference {
namespace analysis {
size_t Dot::counter = 0;
} // namespace analysis
} // namespace inference
namespace framework {
namespace ir {
class AttentionLSTMFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
},
"elementwise_add_out");
pattern->AddEdge(mul_parameter_var, mul_op);
pattern->AddEdge(mul_tmp_input_var, mul_op);
pattern->AddEdge(mul_op, mul_out_var);
pattern->AddEdge(mul_out_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
.LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
.LinksTo({elementwise_add_out_var});
}
// Replace the node `from` in the links to `to`
@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unordered_set<Node*> nodes2delete;
GraphPatternDetecter gpd;
GraphPatternDetector gpd;
BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \
@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the

@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {

@ -0,0 +1,126 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::unordered_set<int> fused_ops({// first lstm
13, 15, 16,
// second lstm
23, 25, 26});
pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
"any_node");
std::unordered_set<Node*> marked_nodes;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
marked_nodes.insert(id);
};
gpd(graph.get(), handler);
// Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
int bias, int hidden, int cell, int xx) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(input);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(cell);
GET_NODE(xx);
GET_NODE(lstm);
OpDesc op_desc;
op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
SET_IN(X, input);
SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h);
SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN
LOG(INFO) << "hidden_n: " << hidden_n->Name();
LOG(INFO) << "cell: " << cell_n->Name();
LOG(INFO) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", false);
auto* op = graph->CreateOpNode(&op_desc);
#define LINK_TO(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
LINK_TO(input_n, op);
LINK_TO(weight_x_n, op);
LINK_TO(weight_h_n, op);
LINK_TO(bias_n, op);
LINK_TO(op, hidden_n);
#undef LINK_TO
return op;
};
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
// remove all the nodes
for (auto* node : marked_nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);

@ -0,0 +1,33 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FCLstmFusePass : public Pass {
public:
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -0,0 +1,44 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kParamScopeAttr[] = "param_scope";
class FusePassBase : public Pass {
public:
void Init(Graph* graph) const { graph_ = graph; }
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
virtual ~FusePassBase() {}
protected:
mutable Graph* graph_;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -99,13 +99,13 @@ class Graph {
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc));
return AddNode(new ir::Node(var_desc, node_count_++));
}
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc));
return AddNode(new ir::Node(op_desc, node_count_++));
}
// Create a control dependency var that connects 2 operations. The
@ -115,13 +115,14 @@ class Graph {
// TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
return AddNode(
new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
}
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type));
return AddNode(new ir::Node(name, type, node_count_++));
}
// Clear all node information of the graph and return the ownership of the
@ -142,12 +143,20 @@ class Graph {
nodes_.erase(node);
}
Node *RetriveNode(int id) {
auto it = id2node_.find(id);
if (it != id2node_.end()) return it->second;
return nullptr;
}
private:
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node);
node_set_.insert(node);
PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
id2node_[node->id()] = node;
return node;
}
@ -157,6 +166,8 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_;
std::map<int, Node *> id2node_;
int node_count_{0};
};
bool IsControlDepVar(const ir::Node &var);

@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n].insert(adj_n);
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
}
}
}

@ -17,7 +17,7 @@
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/platform/enforce.h"
@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
name);
}
nodes_.emplace_back(new PDNode(std::move(teller), name));
nodes_.emplace_back(new PDNode(std::move(teller), this, name));
auto* cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
edges_.emplace_back(a, b);
}
void GraphPatternDetecter::operator()(Graph* graph,
GraphPatternDetecter::handle_t handler) {
void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) return;
auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs);
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0;
for (auto& g : subgraphs) {
LOG(INFO) << "optimizing #" << id++ << " subgraph";
handler(g, graph);
}
}
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false;
@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
return false;
}
std::vector<GraphPatternDetecter::subgraph_t>
GraphPatternDetecter::DetectPatterns() {
std::vector<GraphPatternDetector::subgraph_t>
GraphPatternDetector::DetectPatterns() {
// Init empty subgraphs.
std::vector<GraphPatternDetecter::subgraph_t> result;
std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups;
PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto* first_pnode = pattern_.edges().front().first;
std::array<std::vector<HitGroup>, 2> bi_records;
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result;
for (auto* node : pdnodes2nodes_[first_pnode]) {
HitGroup group;
@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
}
int step = 0;
std::array<std::vector<HitGroup>, 2> bi_records;
bi_records[0] = std::move(init_groups);
// Extend a PDNode to subgraphs by deducing the connection relations defined
@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
auto& pre_groups = bi_records[step % 2];
auto& cur_groups = bi_records[1 - (step++ % 2)];
cur_groups.clear();
if (pre_groups.empty()) break;
// source -> target
for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) {
@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
}
for (auto& group : bi_records[step % 2]) {
GraphPatternDetecter::subgraph_t subgraph;
GraphPatternDetector::subgraph_t subgraph;
for (auto& role : group.roles) {
subgraph.emplace(role.first, role.second);
}
@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
return result;
}
void GraphPatternDetecter::UniquePatterns(
std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) {
void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
if (subgraphs->empty()) return;
std::vector<GraphPatternDetecter::subgraph_t> result;
std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set;
for (auto& g : *subgraphs) {
@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
*subgraphs = result;
}
void GraphPatternDetecter::RemoveOverlappedMatch(
void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t>* subgraphs) {
std::vector<subgraph_t> result;
std::unordered_set<Node*> node_set;
@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
*subgraphs = result;
}
std::string PDPattern::DotString() const {
using inference::analysis::Dot;
Dot dot;
int id = 0;
// Create Nodes
std::unordered_map<PDNode*, std::string> node2dot;
for (const auto& node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
// Create Edges
for (const auto& edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto& src = node2dot.at(edge.first);
auto& trg = node2dot.at(edge.second);
dot.AddEdge(src, trg, {});
}
return dot.Build();
}
PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(this, x);
}
return *this;
}
PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(x, this);
}
return *this;
}
} // namespace ir
} // namespace framework
} // namespace paddle

@ -21,12 +21,14 @@
#include <numeric>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle {
namespace framework {
namespace ir {
class PDPattern;
// Some basic torminolygies:
// Some basic terminologies:
// - PDPattern: a pattern defined as a data flow graph.
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// that meets some conditions defined in `PDNode.teller`.
@ -36,30 +38,43 @@ namespace ir {
struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode.
using teller_t = std::function<bool(Node*)>;
enum class Type { kOp, kVar };
PDNode(teller_t&& teller, const std::string& name = "")
: teller_(teller), name_(name) {
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
}
PDNode(PDNode&& other) = default;
std::vector<PDNode*> inlinks;
std::vector<PDNode*> outlinks;
// this link to others
PDNode& LinksTo(const std::vector<PDNode*>& others);
PDNode& LinksFrom(const std::vector<PDNode*>& others);
bool Tell(Node* node) const {
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
return teller_(node);
}
bool IsOp() const { return type_ == Type::kOp; }
bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; }
PDNode(const PDNode&) = delete;
PDNode& operator=(const PDNode&) = delete;
private:
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
pattern_(pattern),
name_(name),
type_(type) {
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
}
PDNode(PDNode&& other) = default;
friend class PDPattern;
teller_t teller_;
PDPattern* pattern_;
std::string name_;
Type type_;
};
/*
@ -102,6 +117,8 @@ class PDPattern {
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; }
std::string DotString() const;
private:
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST(PDPattern, AddEdge);
@ -117,7 +134,7 @@ class PDPattern {
};
/*
* GraphPatternDetecter helps to detect the specific patterns in the graph.
* GraphPatternDetector helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
*
@ -129,7 +146,7 @@ class PDPattern {
*
* Usage:
* // Create a detector
* GraphPatternDetecter detector;
* GraphPatternDetector detector;
* // Define the detector's pattern, by adding PDNode and define the edges.
* auto* node0 = detector.mutable_pattern().AddNode(...)
* auto* node1 = detector.mutable_pattern().AddNode(...)
@ -138,11 +155,11 @@ class PDPattern {
* detector.mutable_pattern().AddEdge(node0, node1);
* // Create an handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns.
* GraphPatternDetecter::handle_t handler = some labmda
* GraphPatternDetector::handle_t handler = some labmda
* // Execute the detector.
* detector(&graph, handler);
*/
class GraphPatternDetecter {
class GraphPatternDetector {
public:
using subgraph_t = std::unordered_map<PDNode*, Node*>;
@ -177,10 +194,62 @@ class GraphPatternDetecter {
using hit_rcd_t =
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
PDPattern pattern_;
std::vector<hit_rcd_t> marked_records_;
std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
};
// some helper methods.
// Op's input.
static bool VarLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Op's output.
static bool VarLinksFromOp(Node* node, const std::string& op_type) {
for (auto* out : node->inputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Check whether a var node is a op node's nth input.
static bool IsNthInput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <gtest/gtest.h>
@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
}
TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
GraphPatternDetecter x;
GraphPatternDetector x;
// mark o2, o3, v2
// The pattern is a graph:
@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
Graph graph(program);
BuildGraph(&graph);
GraphPatternDetecter x;
GraphPatternDetector x;
// The pattern is a graph:
// op -> var
@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
x.mutable_pattern()->AddEdge(any_var, any_op1);
int count = 0;
GraphPatternDetecter::handle_t handle = [&](
const GraphPatternDetecter::subgraph_t& s, Graph* g) {
GraphPatternDetector::handle_t handle = [&](
const GraphPatternDetector::subgraph_t& s, Graph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
<< s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
count++;

@ -16,11 +16,13 @@ limitations under the License. */
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
using inference::analysis::Dot;
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout;
size_t var_id = 0;
std::unordered_map<const ir::Node*, size_t> vars;
sout << "digraph G {\n";
for (const ir::Node* n : graph->Nodes()) {
if (n->NodeType() != ir::Node::Type::kVariable) continue;
size_t cur_var_id = var_id++;
vars[n] = cur_var_id;
sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
<< std::endl;
}
size_t op_id = 0;
for (const ir::Node* n : graph->Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : n->inputs) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
std::unordered_map<const ir::Node*, std::string> node2dot;
Dot dot;
std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "red")});
std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
// Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "yellow")});
std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "lightgray")});
std::vector<Dot::Attr> marked_var_attrs(
{Dot::Attr("style", "filled,rounded"),
// Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "lightgray")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")";
if (n->IsOp()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
dot.AddNode(node_id, attr, node_id);
} else if (n->IsVar()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_var_attrs : var_attrs;
dot.AddNode(node_id, attr, node_id);
}
for (auto out : n->outputs) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
node2dot[n] = node_id;
}
// Create edges
for (const Node* n : graph->Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& trg_id = node2dot.at(out);
dot.AddEdge(src_id, trg_id, {});
}
}
sout << "}\n";
sout << dot.Build();
return graph;
}
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
Graph* graph) const {
marked_nodes_t res;
if (graph->Has(kGraphvizMarkedNodeAttr)) {
auto& attr = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
res = attr;
attr.clear();
}
return res;
}
} // namespace ir
} // namespace framework
} // namespace paddle

@ -27,10 +27,19 @@ namespace paddle {
namespace framework {
namespace ir {
const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__";
class GraphVizPass : public Pass {
public:
using marked_nodes_t = std::unordered_set<const Node*>;
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
};
} // namespace ir

@ -29,20 +29,26 @@ class Node {
enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
explicit Node(const std::string& name, Type type, int id = -1)
: name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(id) {}
explicit Node(VarDesc* var_desc)
explicit Node(VarDesc* var_desc, int id = -1)
: name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr),
type_(Type::kVariable) {}
type_(Type::kVariable),
id_(id) {}
explicit Node(OpDesc* op_desc)
explicit Node(OpDesc* op_desc, int id = -1)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation) {}
type_(Type::kOperation),
id_(id) {}
Type NodeType() const { return type_; }
@ -58,6 +64,8 @@ class Node {
return op_desc_.get();
}
int id() const { return id_; }
bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; }
@ -69,6 +77,7 @@ class Node {
std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_;
Type type_;
int id_;
private:
DISABLE_COPY_AND_ASSIGN(Node);

File diff suppressed because it is too large Load Diff

@ -0,0 +1,33 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConcatFcFusePass : public FusePassBase {
public:
virtual ~SeqConcatFcFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -1,5 +1,8 @@
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
set(analysis_deps
framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor)
cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc
helper.cc
# passes
@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc
fluid_to_ir_pass.cc
model_store_pass.cc
DEPS framework_proto proto_desc ir_pass_manager graph pass)
DEPS ${analysis_deps})
cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
endif()
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS}
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING)
@ -58,20 +61,25 @@ endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
analysis_predictor
# ir
fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detecter
infer_clean_graph_pass
graph_pattern_detector
infer_clean_graph_pass
attention_lstm_fuse_pass
paddle_inference_api
pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
--infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)

@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
// Ungly support fluid-to-ir-pass
argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({
// Manual update the passes here.
"graph_viz_pass", //
"infer_clean_graph_pass", "graph_viz_pass", //
"attention_lstm_fuse_pass", "graph_viz_pass", //
"fc_lstm_fuse_pass", "graph_viz_pass", //
"seq_concat_fc_fuse_pass", "graph_viz_pass", //
"fc_fuse_pass", "graph_viz_pass" //
}));
for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll();

@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
const std::string &data_path, int batch_size,
bool use_analysis, bool activate_ir,
int num_times = 1) {
FLAGS_IA_enable_ir = activate_ir;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
std::string model_out;
if (use_analysis) {
Argument argument(model_path);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
model_out = "./analysis.out";
ASSERT_TRUE(PathExists(model_out));
} else {
model_out = FLAGS_infer_ditu_rnn_model;
}
NativeConfig config;
config.prog_file = model_out + "/__model__";
config.param_file = model_out + "/param";
config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
config.use_gpu = false;
config.device = 0;
config.specify_input_name = true;
auto predictor =
auto base_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
std::vector<PaddleTensor> input_slots;
DataRecord data(data_path, batch_size);
// Prepare inputs.
PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs;
std::vector<PaddleTensor> outputs, base_outputs;
base_predictor->Run(input_slots, &base_outputs);
Timer timer;
timer.tic();
@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
<< ", latency: " << timer.toc() / num_times << "ms";
LOG(INFO) << "=====================================";
for (auto &out : outputs) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &base_out = base_outputs[i];
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; });
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
1, [](int a, int b) { return a * b; });
PADDLE_ENFORCE_EQ(size, size1);
PADDLE_ENFORCE_GT(size, 0);
float *data = static_cast<float *>(out.data.data());
for (size_t i = 0;
i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size);
i++) {
EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
float *base_data = static_cast<float *>(base_out.data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(data[i], base_data[i], 1e-3);
}
}
}
// Turn on the IR pass supportion, run a real inference and check the result.
TEST(Analyzer, SupportIRPass) {
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE(PathExists("./analysis.out"));
// Inference from this path.
TestWord2vecPrediction("./analysis.out");
}
// Directly infer with the original model.
TEST(Analyzer, DituRNN_without_analysis) {
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);

@ -26,6 +26,7 @@
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace inference {
@ -58,6 +59,46 @@ struct Argument {
// The output storage path of ModelStorePass.
std::unique_ptr<std::string> model_output_store_path;
// Support for any other attributes.
template <typename T>
void Set(const std::string& key, T* data) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
attrs_[key] = data;
attr_deleters_[key] = [data, key, this]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
VLOG(3) << "argument delete attr: " << key;
delete data;
};
}
bool Has(const std::string& name) const { return attrs_.count(name); }
template <typename T>
T* Release(const std::string& key) {
PADDLE_ENFORCE(attrs_.count(key));
auto* res = boost::any_cast<T*>(attrs_.at(key));
attrs_.erase(key);
attr_deleters_.erase(key);
return res;
}
template <typename T>
T& Get(const std::string& key) {
PADDLE_ENFORCE(Has(key));
return *boost::any_cast<T*>(attrs_.at(key));
}
~Argument() {
for (auto& item : attr_deleters_) {
item.second();
}
}
private:
std::unordered_map<std::string, boost::any> attrs_;
std::unordered_map<std::string, std::function<void()>> attr_deleters_;
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

@ -19,6 +19,7 @@
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace paddle {
namespace inference {
@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
}
}
if (argument_->Has("param_scope")) {
LOG(WARNING) << "parameter changes in the scope takes effect";
}
PADDLE_ENFORCE(argument_->transformed_program_desc.get());
}

@ -29,13 +29,13 @@ namespace paddle {
namespace inference {
namespace analysis {
static size_t dot_node_counter{0};
/*
* A Dot template that helps to build a DOT graph definition.
*/
class Dot {
public:
static size_t counter;
struct Attr {
std::string key;
std::string value;
@ -57,7 +57,7 @@ class Dot {
Node(const std::string& name, const std::vector<Attr>& attrs)
: name(name),
attrs(attrs),
id_("node_" + std::to_string(Dot::counter++)) {}
id_("node_" + std::to_string(dot_node_counter++)) {}
std::string id() const { return id_; }
@ -65,6 +65,10 @@ class Dot {
std::stringstream ss;
CHECK(!name.empty());
ss << id_;
if (attrs.empty()) {
ss << "[label=" << '"' << name << '"' << "]";
return ss.str();
}
for (size_t i = 0; i < attrs.size(); i++) {
if (i == 0) {
ss << "[label=" << '"' << name << '"' << " ";
@ -108,9 +112,11 @@ class Dot {
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& name, const std::vector<Attr>& attrs) {
CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'";
nodes_.emplace(name, Node{name, attrs});
void AddNode(const std::string& id, const std::vector<Attr>& attrs,
std::string label = "") {
CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
if (label.empty()) label = id;
nodes_.emplace(id, Node{label, attrs});
}
void AddEdge(const std::string& source, const std::string& target,

@ -13,3 +13,47 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace analysis {
void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file) {
PADDLE_ENFORCE(argument_);
argument_->Set("param_scope", new framework::Scope);
// Load parameters.
VLOG(3) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
prog_file, param_file);
}
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file,
const std::string &param_file) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor executor(place);
PADDLE_ENFORCE(argument_->origin_program_desc.get());
framework::ProgramDesc program(*argument_->origin_program_desc);
if ((!prog_file.empty()) && (!param_file.empty())) {
LOG(INFO) << "load single model file from " << prog_file;
Load(&executor, scope, prog_file, param_file);
} else if (!dir.empty()) {
LOG(INFO) << "load from dir " << dir;
Load(&executor, scope, dir);
} else {
LOG(ERROR) << "failed to load parameters";
return false;
}
return true;
}
} // namespace analysis
} // namespace inference
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save