|
|
|
@ -179,15 +179,15 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input,
|
|
|
|
|
Node* conv_filter,
|
|
|
|
|
Node* conv_output,
|
|
|
|
|
Node* conv_filter, Node* conv_output,
|
|
|
|
|
Node* elementwise_add_x) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType(conv_pattern.op_name());
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.residual_data_name(), {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.residual_data_name(),
|
|
|
|
|
{elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput(conv_pattern.output_name(), {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("use_mkldnn", true);
|
|
|
|
@ -201,8 +201,9 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr, fuse_conv]
|
|
|
|
|
(const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
|
|
|
|
|
auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr,
|
|
|
|
|
fuse_conv](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.op_name());
|
|
|
|
|
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|