|
|
|
@ -73,19 +73,19 @@ struct ElementwiseAdd {
|
|
|
|
|
auto elementwise_add_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("elementwise_add");
|
|
|
|
|
|
|
|
|
|
auto y_var = pattern->new_node(y_name())
|
|
|
|
|
auto x_var = pattern->new_node(x_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
y_name());
|
|
|
|
|
x_name());
|
|
|
|
|
|
|
|
|
|
conv_output->assert_is_op_input(op_name(),
|
|
|
|
|
x_name());
|
|
|
|
|
y_name());
|
|
|
|
|
|
|
|
|
|
auto out_var = pattern->new_node(out_name())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
out_name());
|
|
|
|
|
|
|
|
|
|
elementwise_add_op->LinksFrom({y_var, conv_output});
|
|
|
|
|
elementwise_add_op->LinksFrom({x_var, conv_output});
|
|
|
|
|
elementwise_add_op->LinksTo({out_var});
|
|
|
|
|
|
|
|
|
|
return out_var;
|
|
|
|
@ -139,13 +139,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
|
|
|
|
|
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {y->Name()});
|
|
|
|
|
op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("use_mkldnn", true);
|
|
|
|
|
op_desc.SetAttr("fuse_eltwise", true);
|
|
|
|
@ -154,7 +155,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
patterns::LinkNodes(conv_input, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(conv_filter, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, y);
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
|
|
|
|
@ -169,14 +170,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.op_name());
|
|
|
|
|
auto elementwise_add_y = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.y_name());
|
|
|
|
|
auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.x_name());
|
|
|
|
|
auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.out_name());
|
|
|
|
|
|
|
|
|
|
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
|
|
|
|
|
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
gpd(graph.get(), handler);
|
|
|
|
|