MKLDNN conv + elementwise_add fusion: implementation of patterns refarctored, applied to graph. UTs added
parent
9ce343f868
commit
604bad08bc
@ -0,0 +1,178 @@
|
||||
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace patterns {
|
||||
|
||||
struct Pattern : public PatternBase {
|
||||
Pattern(PDPattern* pattern, const std::string& name_scope)
|
||||
: PatternBase{pattern, name_scope, ""}
|
||||
{ }
|
||||
|
||||
private:
|
||||
std::string name_scope() { return name_scope_; }
|
||||
std::string repr() { return repr_; }
|
||||
size_t id() { return id_; }
|
||||
PDPattern* node_pattern() { return pattern; }
|
||||
|
||||
public:
|
||||
std::string node_name(std::string op_name)
|
||||
{
|
||||
return PDNodeName(name_scope(), repr(), id(), op_name);
|
||||
}
|
||||
|
||||
PDNode* retrieve_node(std::string op_name)
|
||||
{
|
||||
return node_pattern()->RetrieveNode(node_name(op_name));
|
||||
}
|
||||
|
||||
PDNode* new_node(std::string op_name)
|
||||
{
|
||||
return node_pattern()->NewNode(node_name(op_name));
|
||||
}
|
||||
};
|
||||
|
||||
struct Conv {
|
||||
std::string conv_name() { return "conv2d"; }
|
||||
std::string input_name() { return "Input"; }
|
||||
std::string filter_name() { return "Filter"; }
|
||||
std::string output_name() { return "Output"; }
|
||||
|
||||
std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) {
|
||||
return [&]() -> PDNode* {
|
||||
auto conv_op = pattern->new_node(conv_name())
|
||||
->assert_is_op("conv2d");
|
||||
|
||||
auto input_var = pattern->new_node(input_name())
|
||||
->AsInput()
|
||||
->assert_is_op_input(conv_name());
|
||||
|
||||
auto filter_var = pattern->new_node(filter_name())
|
||||
->AsInput()
|
||||
->assert_is_persistable_var()
|
||||
->assert_is_op_input(conv_name());
|
||||
|
||||
auto output_var = pattern->new_node(output_name())
|
||||
->AsOutput()
|
||||
->assert_is_op_output(conv_name());
|
||||
|
||||
conv_op->LinksFrom({input_var, filter_var});
|
||||
conv_op->LinksTo({output_var});
|
||||
|
||||
return output_var;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ElementwiseAdd {
|
||||
std::string elementwise_add_name() { return "elementwise_add"; }
|
||||
std::string x_name() { return "X"; }
|
||||
std::string y_name() { return "Y"; }
|
||||
std::string out_name() { return "Out"; }
|
||||
|
||||
std::function<PDNode* (PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
|
||||
return [&](PDNode* conv_output) -> PDNode* {
|
||||
auto elementwise_add_op = pattern->new_node(elementwise_add_name())
|
||||
->assert_is_op("elementwise_add");
|
||||
|
||||
auto y_var = pattern->new_node(y_name())
|
||||
->AsInput()
|
||||
->assert_is_op_input(elementwise_add_name());
|
||||
|
||||
conv_output->assert_is_op_input(pattern->node_name(elementwise_add_name()),
|
||||
pattern->node_name(x_name()));
|
||||
// auto y_var = pattern->NewNode(y_name())
|
||||
// ->AsInput()
|
||||
// ->assert_is_op_input(elementwise_add_name());
|
||||
|
||||
auto out_var = pattern->new_node(out_name())
|
||||
->AsOutput()
|
||||
->assert_is_op_output(
|
||||
pattern->node_name(elementwise_add_name()));
|
||||
|
||||
elementwise_add_op->LinksFrom({y_var, conv_output});
|
||||
elementwise_add_op->LinksTo({out_var});
|
||||
|
||||
return out_var;
|
||||
};
|
||||
}
|
||||
};
|
||||
} // namespace patterns
|
||||
|
||||
Node* node_from_subgraph(const GraphPatternDetector::subgraph_t& subgraph,
|
||||
std::shared_ptr<patterns::Pattern> pattern, const std::string& op_name)
|
||||
{
|
||||
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
|
||||
"Node not found for PDNode %s", pattern->node_name(op_name));
|
||||
Node* var = subgraph.at(pattern->retrieve_node(op_name));
|
||||
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph");
|
||||
|
||||
return var;
|
||||
}
|
||||
|
||||
using graph_ptr = std::unique_ptr<ir::Graph>;
|
||||
|
||||
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
||||
FusePassBase::Init("conv_elementwise_add_mkldnn_fuse_pass", graph.get());
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
auto pattern = gpd.mutable_pattern();
|
||||
|
||||
auto pattern_ptr = std::make_shared<patterns::Pattern>(pattern, name_scope_);
|
||||
|
||||
patterns::Conv conv_pattern;
|
||||
auto conv_output = conv_pattern(pattern_ptr)();
|
||||
conv_output->AsIntermediate();
|
||||
|
||||
patterns::ElementwiseAdd elementwise_add_pattern;
|
||||
elementwise_add_pattern(pattern_ptr)(conv_output);
|
||||
|
||||
auto link_nodes_to = [](Node* a, Node* b) {
|
||||
a->outputs.push_back(b);
|
||||
b->inputs.push_back(a);
|
||||
};
|
||||
|
||||
auto fuse_conv = [&](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
|
||||
OpDesc op_desc;
|
||||
op_desc.SetType("conv2d");
|
||||
|
||||
op_desc.SetInput("Input", {conv_input->Name()});
|
||||
op_desc.SetInput("Filter", {conv_filter->Name()});
|
||||
op_desc.SetOutput("Ouput", {y->Name()});
|
||||
|
||||
op_desc.SetAttr("fuse_sum", true);
|
||||
|
||||
auto fused_conv_op = g->CreateOpNode(&op_desc);
|
||||
|
||||
link_nodes_to(conv_input, fused_conv_op);
|
||||
link_nodes_to(conv_filter, fused_conv_op);
|
||||
link_nodes_to(fused_conv_op, y);
|
||||
};
|
||||
|
||||
auto remove_unused_nodes = [](Graph* g, const std::unordered_set<const Node*>& removed_nodes) {
|
||||
GraphSafeRemoveNodes(g, removed_nodes);
|
||||
};
|
||||
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
|
||||
auto elementwise_add_x = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.x_name());
|
||||
auto elementwise_add_y = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.y_name());
|
||||
auto elementwise_add_out = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.out_name());
|
||||
|
||||
auto conv_filter = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.filter_name());
|
||||
auto conv_input = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.input_name());
|
||||
auto conv_output = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.output_name());
|
||||
|
||||
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
|
||||
remove_unused_nodes(g, {elementwise_add_x, conv_output, elementwise_add_out});
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
|
||||
return graph;
|
||||
}
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
|
@ -0,0 +1,81 @@
|
||||
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
void SetOp(ProgramDesc* prog, const std::string& type,
|
||||
const std::vector<std::string>& inputs,
|
||||
const std::vector<std::string>& outputs) {
|
||||
auto op = prog->MutableBlock(0)->AppendOp();
|
||||
op->SetType(type);
|
||||
|
||||
if (type == "conv2d") {
|
||||
op->SetAttr("use_mkldnn", true);
|
||||
op->SetInput("Input", {inputs[0]});
|
||||
op->SetInput("Filter", {inputs[1]});
|
||||
op->SetInput("Output", {outputs});
|
||||
} else if (type == "elementwise_add") {
|
||||
op->SetInput("X", {inputs[0]});
|
||||
op->SetInput("Y", {inputs[1]});
|
||||
op->SetOutput("Out", outputs);
|
||||
}
|
||||
}
|
||||
|
||||
ProgramDesc BuildProgramDesc() {
|
||||
ProgramDesc prog;
|
||||
for (auto& v :
|
||||
std::vector<std::string>({"a", "b", "c", "d", "weights", "f", "g"})) {
|
||||
auto* var = prog.MutableBlock(0)->Var(v);
|
||||
var->SetType(proto::VarType::LOD_TENSOR);
|
||||
if (v == "weights" || v == "bias") {
|
||||
var->SetPersistable(true);
|
||||
}
|
||||
}
|
||||
|
||||
SetOp(&prog, "OP0", {"a"}, {"b"});
|
||||
SetOp(&prog, "OP1", {"c"}, {"d"});
|
||||
SetOp(&prog, "conv2d", {"d", "weights"}, {"f"});
|
||||
SetOp(&prog, "elemenwise_add", {"d", "f"}, {"g"});
|
||||
|
||||
return prog;
|
||||
}
|
||||
|
||||
TEST(ConvElementwiseAddMKLDNNFusePass, basic) {
|
||||
auto prog = BuildProgramDesc();
|
||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
||||
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
|
||||
int original_nodes_num = graph->Nodes().size();
|
||||
graph = pass->Apply(std::move(graph));
|
||||
int current_nodes_num = graph->Nodes().size();
|
||||
|
||||
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
|
||||
// Assert conv_relu op in newly generated graph
|
||||
int conv_elementwise_add_count = 0;
|
||||
|
||||
for (auto* node : graph->Nodes()) {
|
||||
if (node->IsOp() && node->Op()->Type() == "conv2d") {
|
||||
if (node->Op()->HasAttr("use_mkldnn")) {
|
||||
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
|
||||
if (use_mkldnn) {
|
||||
// TODO tpatejko: it is commented because convolution does not support this attribute
|
||||
if (true/*node->Op()->HasAttr("fuse_sum")*/) {
|
||||
// bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
|
||||
if (true /*fuse_sum*/) {
|
||||
++conv_elementwise_add_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(conv_elementwise_add_count, 1);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
|
@ -1,174 +0,0 @@
|
||||
#include "paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace patterns {
|
||||
|
||||
struct PatternNode {
|
||||
PatternNode(PDPattern* pattern,
|
||||
const std::string& name,
|
||||
const std::string& name_scope,
|
||||
const std::string& repr,
|
||||
size_t id)
|
||||
: nodeName{PDNodeName(name_scope, repr, id, name)}
|
||||
, node{pattern->RetrieveNode(nodeName)
|
||||
{ }
|
||||
|
||||
std::string name() { return nodeName };
|
||||
PDNode* node() { return node };
|
||||
|
||||
private:
|
||||
std::string nodeName;
|
||||
PDNode* node;
|
||||
};
|
||||
/*
|
||||
|
||||
struct Conv : public PatternBase {
|
||||
Conv(PDPattern* pattern, const std::string& name_scope)
|
||||
: PatternBase{pattern, name_scope, "conv"}
|
||||
, conv{pattern, "conv", name_scope_, repr_, id_}
|
||||
, input{pattern, "Input", name_scope_, repr_, id_}
|
||||
, filter{pattern, "Filter", name_scope_, repr_, id_}
|
||||
, output{pattern, "Output", node_scope_, repr_ id_}
|
||||
{ }
|
||||
|
||||
private:
|
||||
PatternNode conv;
|
||||
PatternNode input;
|
||||
PatternNode filter;
|
||||
PatternNode output;
|
||||
|
||||
public:
|
||||
PDNode* operator()() {
|
||||
auto conv_op = pattern->NewNode(conv.name())
|
||||
->assert_is_op("conv2d");
|
||||
|
||||
auto input_var = pattern->NewNode(input.name())
|
||||
->AsInput()
|
||||
->assert_is_op_input(conv.name());
|
||||
|
||||
auto filter_var = pattern->NewNode(filter.name())
|
||||
->AsInput()
|
||||
->assert_is_persistable_var()
|
||||
->assert_is_op_input(conv.name());
|
||||
|
||||
auto output_var = patterh->NewNode(output.name())
|
||||
->AsOutput()
|
||||
->assert_is_op_output(conv.name());
|
||||
|
||||
conv_op->LinksFrom({input_var, filter_var});
|
||||
conv_op->LinksTo({output_var};
|
||||
|
||||
return output_var;
|
||||
}
|
||||
};
|
||||
*/
|
||||
|
||||
struct Conv : public PatternBase {
|
||||
Conv(PDPattern* pattern, const std::string& name_scope)
|
||||
: PatternBase{pattern, name_scope, "conv"}
|
||||
{ }
|
||||
|
||||
std::string conv_name() { return PDNodeName(name_scope_, repr_, id_, "conv2d"); }
|
||||
PDNode* conv_node() { return pattern->RetrieveNode(conv_name()); }
|
||||
|
||||
std::string input_name() { return PDNodeName(name_scope, repr_, id_, "Input"); }
|
||||
PDNode* input_node() { return pattern->RetrieveNode(input_name()); }
|
||||
|
||||
std::string filter_name() { return PDNodeName(name_scope_, repr_, id_, "Filter"); }
|
||||
PDNode* filter_node() { return pattern->RetrieveNode(filter_name()); }
|
||||
|
||||
std::string output_name() { return PDNodeName(name_scope, repr_, id_, "Output"); }
|
||||
PDNode* output_node() { return pattern->RetrieveNode(output_name()); }
|
||||
|
||||
PDNode* operator()() {
|
||||
auto conv_op = pattern->NewNode(conv_name())
|
||||
->assert_is_op("conv2d");
|
||||
|
||||
auto input_var = pattern->NewNode(input_name())
|
||||
->AsInput()
|
||||
->assert_is_op_input(conv_name());
|
||||
|
||||
auto filter_var = pattern->NewNode(filter_name())
|
||||
->AsInput()
|
||||
->assert_is_persistable_var()
|
||||
->assert_is_op_input(conv_name());
|
||||
|
||||
auto output_var = patterh->NewNode(output_name())
|
||||
->AsOutput()
|
||||
->assert_is_op_output(conv_name());
|
||||
|
||||
conv_op->LinksFrom({input_var, filter_var});
|
||||
conv_op->LinksTo({output_var};
|
||||
|
||||
return output_var;
|
||||
}
|
||||
};
|
||||
|
||||
struct ElementwiseAdd : public PatternBase {
|
||||
Conv(PDPattern* pattern, const std::string& name_scope)
|
||||
: PatternBase{pattern, name_scope, "elementwise_add"}
|
||||
{ }
|
||||
|
||||
std::string elementwise_add_name() { return PDNodeName(name_scope_, repr_, id_, "elementwise_add"); }
|
||||
PDNode* elementwise_add_node() { return pattern->RetrieveNode(elementwise_add_name()); }
|
||||
|
||||
std::string x_name() { return PDNodeName(name_scope, repr_, id_, "X"); }
|
||||
PDNode* x_node() { return pattern->RetrieveNode(x_name()); }
|
||||
|
||||
std::string y_name() { return PDNodeName(name_scope_, repr_, id_, "Y"); }
|
||||
PDNode* y_node() { return pattern->RetrieveNode(y_name()); }
|
||||
|
||||
std::string out_name() { return PDNodeName(name_scope, repr_, id_, "Out"); }
|
||||
PDNode* out_node() { return pattern->RetrieveNode(out_name()); }
|
||||
|
||||
PDNode* operator()(PDNode* conv_output) {
|
||||
auto elementwise_add_op = pattern->NewNode(conv_name())
|
||||
->assert_is_op("elementwise_add");
|
||||
|
||||
auto x_var = pattern->NewNode(x_name())
|
||||
->AsInput()
|
||||
->assert_is_op_input(elementwise_add_name());
|
||||
|
||||
conv_output->assert_is_op_input(elementwise_add_name(), y_name());
|
||||
// auto y_var = pattern->NewNode(y_name())
|
||||
// ->AsInput()
|
||||
// ->assert_is_op_input(elementwise_add_name());
|
||||
|
||||
auto out_var = pattern->NewNode(out_name())
|
||||
->AsOutput()
|
||||
->assert_is_op_output(elementwise_add_name());
|
||||
|
||||
conv_op->LinksFrom({x_var, conv_output});
|
||||
conv_op->LinksTo({out_var};
|
||||
|
||||
return out_var;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace patterns
|
||||
|
||||
using graph_ptr = std::unique_ptr<ir::Graph>;
|
||||
|
||||
graph_ptr MKLDNNConvElementwiseAddFusePass::ApplyImpl(graph_ptr) const {
|
||||
FusePassBase::Init("mkldnn_conv_elementwise_add_fuse", graph.get());
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
auto pattern = gpd.mutable_pattern();
|
||||
|
||||
patterns::Conv conv_pattern(pattern, name_scope_);
|
||||
auto conv_output = conv_pattern();
|
||||
conv_output->AsIntermediate();
|
||||
|
||||
patterns::ElementwiseAdd elementwise_add_pattern(pattern, name_scope_);
|
||||
auto elementwis_add_output = elementwise_add_pattern(conv_output);
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue