|
|
|
@ -13,11 +13,13 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
|
|
|
|
|
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <list>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
|
|
|
|
@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_y,
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_y,
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
conv_output);
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_x,
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_x,
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
|
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
|
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
|
|
|
|
|
REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
|
|
|
|
|
.AddCombination(
|
|
|
|
|
paddle::framework::compatible::OpVersionComparatorCombination()
|
|
|
|
|
.EQ("conv2d", 0)
|
|
|
|
|
.LE("conv2d", 1)
|
|
|
|
|
.EQ("elementwise_add", 0));
|
|
|
|
|