|
|
|
@ -13,11 +13,13 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
|
|
|
|
|
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <list>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
|
|
|
|
@ -226,7 +228,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
@ -263,7 +266,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
conv_output);
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
@ -302,7 +306,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
|
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
auto get_node_from_elementwise_add =
|
|
|
|
|
[&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
|
|
|
|
|
REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
|
|
|
|
|
.AddCombination(
|
|
|
|
|
paddle::framework::compatible::OpVersionComparatorCombination()
|
|
|
|
|
.EQ("conv2d", 0)
|
|
|
|
|
.LE("conv2d", 1)
|
|
|
|
|
.EQ("elementwise_add", 0));
|
|
|
|
|