|
|
|
@ -54,16 +54,15 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
using graph_ptr = std::unique_ptr<ir::Graph>;
|
|
|
|
|
|
|
|
|
|
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
FusePassBase::Init("conv_elementwise_add_mkldnn_fuse_pass", graph.get());
|
|
|
|
|
FusePassBase::Init(name_scope_, graph.get());
|
|
|
|
|
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
auto pattern = gpd.mutable_pattern();
|
|
|
|
|
|
|
|
|
|
patterns::Conv conv_pattern{pattern, "skip_connections_fusion"};
|
|
|
|
|
patterns::Conv conv_pattern{pattern, name_scope_};
|
|
|
|
|
auto conv_output = conv_pattern();
|
|
|
|
|
|
|
|
|
|
patterns::ElementwiseAdd elementwise_add_pattern{pattern,
|
|
|
|
|
"skip_connections_fusion"};
|
|
|
|
|
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
|
|
|
|
|
elementwise_add_pattern(conv_output);
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|