|
|
@ -228,20 +228,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_y,
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_y,
|
|
|
|
elementwise_add_out);
|
|
|
|
elementwise_add_out);
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
&gpd, graph_with_stats,
|
|
|
@ -266,20 +265,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
conv_output);
|
|
|
|
conv_output);
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_x,
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_x,
|
|
|
|
elementwise_add_out);
|
|
|
|
elementwise_add_out);
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
&gpd, graph_with_stats,
|
|
|
@ -306,17 +304,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
elementwise_add_pattern);
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_out);
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_out);
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
|
|
|
|
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
&gpd, graph_with_stats,
|
|
|
@ -351,4 +348,4 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
|
|
|
|
.AddCombination(
|
|
|
|
.AddCombination(
|
|
|
|
paddle::framework::compatible::OpVersionComparatorCombination()
|
|
|
|
paddle::framework::compatible::OpVersionComparatorCombination()
|
|
|
|
.LE("conv2d", 1)
|
|
|
|
.LE("conv2d", 1)
|
|
|
|
.EQ("elementwise_add", 0));
|
|
|
|
.LE("elementwise_add", 1));
|
|
|
|