|
|
@ -228,8 +228,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
@ -266,8 +265,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
conv_output);
|
|
|
|
conv_output);
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
@ -306,8 +304,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
@ -351,4 +348,4 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
|
|
|
|
.AddCombination(
|
|
|
|
.AddCombination(
|
|
|
|
paddle::framework::compatible::OpVersionComparatorCombination()
|
|
|
|
paddle::framework::compatible::OpVersionComparatorCombination()
|
|
|
|
.LE("conv2d", 1)
|
|
|
|
.LE("conv2d", 1)
|
|
|
|
.EQ("elementwise_add", 0));
|
|
|
|
.LE("elementwise_add", 1));
|
|
|
|