|
|
|
@ -22,6 +22,7 @@ namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace patterns {
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
struct Pattern : public PatternBase {
|
|
|
|
|
Pattern(PDPattern* pattern, const std::string& name_scope)
|
|
|
|
|
: PatternBase{pattern, name_scope, ""} {}
|
|
|
|
@ -45,7 +46,8 @@ struct Pattern : public PatternBase {
|
|
|
|
|
return node_pattern()->NewNode(node_name(op_name));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
/*
|
|
|
|
|
struct Conv {
|
|
|
|
|
std::string op_name() const { return "conv2d"; }
|
|
|
|
|
std::string input_name() const { return "Input"; }
|
|
|
|
@ -105,7 +107,8 @@ struct ElementwiseAdd {
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
/*
|
|
|
|
|
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
std::shared_ptr<patterns::Pattern> pattern,
|
|
|
|
|
const std::string& op_name) {
|
|
|
|
@ -116,6 +119,7 @@ Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
void LinkNodes(Node* from, Node* to) {
|
|
|
|
|
from->outputs.push_back(to);
|
|
|
|
@ -172,64 +176,50 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
auto pattern = gpd.mutable_pattern();
|
|
|
|
|
auto pattern_ptr = std::make_shared<patterns::Pattern>(pattern, name_scope_);
|
|
|
|
|
|
|
|
|
|
patterns::Conv conv_pattern;
|
|
|
|
|
auto conv_output = conv_pattern(pattern_ptr)();
|
|
|
|
|
patterns::Conv conv_pattern{pattern, "skip_connections_fusion"};
|
|
|
|
|
auto conv_output = conv_pattern();
|
|
|
|
|
|
|
|
|
|
patterns::ElementwiseAdd elementwise_add_pattern;
|
|
|
|
|
elementwise_add_pattern(pattern_ptr)(conv_output);
|
|
|
|
|
patterns::ElementwiseAdd elementwise_add_pattern{pattern,
|
|
|
|
|
"skip_connections_fusion"};
|
|
|
|
|
elementwise_add_pattern(conv_output);
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input, Node* conv_bias,
|
|
|
|
|
Node* conv_filter, Node* conv_output,
|
|
|
|
|
Node* elementwise_add_x) {
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType(conv_pattern.op_name());
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.bias_name(), {conv_bias->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.residual_data_name(),
|
|
|
|
|
{elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput(conv_pattern.output_name(), {conv_output->Name()});
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Bias", {conv_bias->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("use_mkldnn", true);
|
|
|
|
|
op_desc.SetAttr("fuse_eltwise", true);
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = g->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
patterns::LinkNodes(conv_input, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(conv_bias, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(conv_filter, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(elementwise_add_x, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, conv_output);
|
|
|
|
|
};
|
|
|
|
|
IR_NODE_LINK_TO(conv_input, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, conv_output);
|
|
|
|
|
|
|
|
|
|
auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr,
|
|
|
|
|
fuse_conv](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.op_name());
|
|
|
|
|
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.input_name());
|
|
|
|
|
auto conv_bias = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.bias_name());
|
|
|
|
|
auto conv_filter = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, conv_pattern.filter_name());
|
|
|
|
|
auto conv_output = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, conv_pattern.output_name());
|
|
|
|
|
|
|
|
|
|
auto elementwise_add_op = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, elementwise_add_pattern.op_name());
|
|
|
|
|
auto elementwise_add_x = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, elementwise_add_pattern.x_name());
|
|
|
|
|
auto elementwise_add_out = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, elementwise_add_pattern.out_name());
|
|
|
|
|
|
|
|
|
|
fuse_conv(g, conv_input, conv_bias, conv_filter, conv_output,
|
|
|
|
|
elementwise_add_x);
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
};
|
|
|
|
|