Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into disable_prelu_test_local
commit
efe88ab9f9
@ -0,0 +1,192 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
bool VarOutLinksToOp(Node* node, const std::string& op_type) {
|
||||
for (auto* out : node->outputs) {
|
||||
if (out->IsOp() && out->Op()->Type() == op_type) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void BuildFCPattern(PDPattern* pattern) {
|
||||
// make sure the selected MUL op has one input argument is a parameter.
|
||||
auto* mul_parameter_var = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsVar() && node->outputs.size() == 1UL &&
|
||||
node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
|
||||
node->Var()->Persistable(); // check is a parameter
|
||||
},
|
||||
"mul_weight" /*name*/);
|
||||
|
||||
auto* mul_tmp_input_var = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
bool result =
|
||||
node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
|
||||
!node->Var()->Persistable(); // this input is not an parameter.
|
||||
if (!result) return false;
|
||||
// check whether one output is MUL op.
|
||||
for (auto* op : node->outputs) {
|
||||
if (op->IsOp() && op->Op()->Type() == "mul") return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
"mul_tmp_var" /*name*/);
|
||||
|
||||
// select a MUL op
|
||||
auto* mul_op = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsOp() && // start from an Op
|
||||
node->Op()->Type() == "mul"; // type is mul
|
||||
// the output should be consumed only by one element_add, that check
|
||||
// leaves in a Var PDNode.
|
||||
},
|
||||
"mul" /*name*/);
|
||||
|
||||
// make sure the MUL op's output has only one consumer and links to an
|
||||
// ELEMENTWISE_ADD op.
|
||||
auto* mul_out_var = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsVar() && // starts from a Var
|
||||
node->outputs.size() == 1UL && // only has one consumer
|
||||
node->outputs.front()->IsOp() && // check basic logic
|
||||
node->Var() && // not a ControlDepVar
|
||||
node->outputs.front()->Op()->Type() ==
|
||||
"elementwise_add"; // a very strong validation
|
||||
},
|
||||
"mul_out");
|
||||
// this check is not essential, just to make the corresponding variable Node
|
||||
// retrival easier.
|
||||
auto* elementwise_add_tmp_var = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
|
||||
VarOutLinksToOp(node, "elementwise_add");
|
||||
},
|
||||
"elementwise_add_tmpvar");
|
||||
|
||||
// select an ELEMENTWISE_ADD op
|
||||
auto* elementwise_add_op = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsOp() && node->Op()->Type() == "elementwise_add";
|
||||
},
|
||||
"elementwise_add" /*name*/);
|
||||
|
||||
// get the ELEMENTWISE_ADD op's output
|
||||
auto* elementwise_add_out_var = pattern->NewNode(
|
||||
[](Node* node) {
|
||||
return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
|
||||
node->inputs.front()->Op()->Type() == "elementwise_add";
|
||||
},
|
||||
"elementwise_add_out");
|
||||
|
||||
pattern->AddEdge(mul_parameter_var, mul_op);
|
||||
pattern->AddEdge(mul_tmp_input_var, mul_op);
|
||||
pattern->AddEdge(mul_op, mul_out_var);
|
||||
pattern->AddEdge(mul_out_var, elementwise_add_op);
|
||||
pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
|
||||
pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
|
||||
}
|
||||
|
||||
// Replace the node `from` in the links to `to`
|
||||
bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
|
||||
for (auto*& n : *links) {
|
||||
if (n == from) {
|
||||
n = to;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
PADDLE_ENFORCE(graph.get());
|
||||
|
||||
std::unordered_set<Node*> nodes2delete;
|
||||
|
||||
GraphPatternDetecter gpd;
|
||||
BuildFCPattern(gpd.mutable_pattern());
|
||||
|
||||
#define GET_NODE(id) \
|
||||
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
|
||||
"pattern has no Node called %s", #id); \
|
||||
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
|
||||
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
|
||||
|
||||
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
VLOG(4) << "handle FC fuse";
|
||||
// Currently, there is no FC op available, so I will just simulate the
|
||||
// scenerio.
|
||||
// FC's fusion is simple, just op fuse, no need to process the
|
||||
// parameters.
|
||||
GET_NODE(mul_tmp_var); // x
|
||||
GET_NODE(mul_weight); // Y
|
||||
GET_NODE(elementwise_add_tmpvar); // bias
|
||||
GET_NODE(elementwise_add_out); // Out
|
||||
GET_NODE(mul); // MUL op
|
||||
GET_NODE(elementwise_add); // ELEMENT_ADD op
|
||||
GET_NODE(mul_out); // tmp
|
||||
#undef GET_NODE
|
||||
|
||||
// Create an FC Node.
|
||||
OpDesc desc;
|
||||
std::string fc_x_in = mul_tmp_var->Name();
|
||||
std::string fc_Y_in = mul_weight->Name();
|
||||
std::string fc_bias_in = elementwise_add_tmpvar->Name();
|
||||
std::string fc_out = elementwise_add_out->Name();
|
||||
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
|
||||
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
|
||||
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
|
||||
desc.SetOutput("Out", std::vector<std::string>({fc_out}));
|
||||
desc.SetType("fc");
|
||||
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
|
||||
fc_node->inputs =
|
||||
std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
|
||||
fc_node->outputs.push_back(elementwise_add_out);
|
||||
|
||||
// Update link relatons
|
||||
PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
|
||||
PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
|
||||
PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
|
||||
elementwise_add, fc_node));
|
||||
PADDLE_ENFORCE(
|
||||
LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
|
||||
|
||||
// Drop old nodes
|
||||
graph->RemoveNode(mul);
|
||||
graph->RemoveNode(elementwise_add);
|
||||
graph->RemoveNode(mul_out); // tmp variable
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass);
|
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
|
||||
*/
|
||||
class FCFusePass : public Pass {
|
||||
public:
|
||||
virtual ~FCFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
void SetOp(ProgramDesc* prog, const std::string& type,
|
||||
const std::vector<std::string>& inputs,
|
||||
const std::vector<std::string>& outputs) {
|
||||
auto* op = prog->MutableBlock(0)->AppendOp();
|
||||
op->SetType(type);
|
||||
op->SetInput("Xs", inputs);
|
||||
op->SetOutput("Ys", outputs);
|
||||
}
|
||||
|
||||
// a->OP0->b
|
||||
// a->OP1->c
|
||||
// (b, c)->mul->d
|
||||
// (d, e)->elementwise_add->f
|
||||
ProgramDesc BuildProgramDesc() {
|
||||
ProgramDesc prog;
|
||||
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
|
||||
auto* var = prog.MutableBlock(0)->Var(v);
|
||||
var->SetType(proto::VarType::SELECTED_ROWS);
|
||||
if (v == "c") {
|
||||
var->SetPersistable(true);
|
||||
}
|
||||
}
|
||||
|
||||
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
|
||||
std::vector<std::string>({"b"}));
|
||||
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
|
||||
std::vector<std::string>({"c"}));
|
||||
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
|
||||
std::vector<std::string>({"d"}));
|
||||
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
|
||||
std::vector<std::string>({"f"}));
|
||||
|
||||
return prog;
|
||||
}
|
||||
|
||||
TEST(FCFusePass, basic) {
|
||||
auto prog = BuildProgramDesc();
|
||||
|
||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
||||
|
||||
auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
|
||||
|
||||
int pre_nodes = graph->Nodes().size();
|
||||
|
||||
graph = pass->Apply(std::move(graph));
|
||||
|
||||
int after_nodes = graph->Nodes().size();
|
||||
|
||||
// Remove 3 Nodes: MUL,ELEMENTWISE_ADD, mul_out
|
||||
// Add 1 Node: FC
|
||||
EXPECT_EQ(pre_nodes - 2, after_nodes);
|
||||
|
||||
// Assert fc op in newly generated graph
|
||||
int fc_count = 0;
|
||||
|
||||
for (auto* node : graph->Nodes()) {
|
||||
if (node->IsOp() && node->Op()->Type() == "fc") {
|
||||
++fc_count;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(fc_count, 1);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
USE_PASS(fc_fuse_pass);
|
@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class InferCleanGraphPass : public Pass {
|
||||
public:
|
||||
virtual ~InferCleanGraphPass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
|
||||
PADDLE_ENFORCE(graph.get());
|
||||
|
||||
auto is_valid_node = [](Node* x) {
|
||||
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
|
||||
};
|
||||
|
||||
std::unordered_set<Node*> invalid_nodes;
|
||||
for (auto* node : graph->Nodes()) {
|
||||
if (is_valid_node(node)) {
|
||||
invalid_nodes.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
// remove nodes from the graph.
|
||||
for (auto* node : invalid_nodes) {
|
||||
graph->RemoveNode(node);
|
||||
}
|
||||
|
||||
// clean edges.
|
||||
for (auto* node : graph->Nodes()) {
|
||||
CleanEdges(&node->inputs, invalid_nodes);
|
||||
CleanEdges(&node->outputs, invalid_nodes);
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
void CleanEdges(std::vector<Node*>* nodes,
|
||||
const std::unordered_set<Node*>& to_remove) const {
|
||||
auto it = std::remove_if(nodes->begin(), nodes->end(),
|
||||
[&](Node* x) { return to_remove.count(x); });
|
||||
nodes->erase(it, nodes->end());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(infer_clean_graph_pass,
|
||||
paddle::framework::ir::InferCleanGraphPass);
|
@ -0,0 +1,15 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue