|
|
|
@ -39,25 +39,25 @@ struct Conv {
|
|
|
|
|
|
|
|
|
|
std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&]() -> PDNode* {
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("conv2d");
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("conv2d");
|
|
|
|
|
|
|
|
|
|
auto input_var = pattern->new_node(input_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
input_name());
|
|
|
|
|
|
|
|
|
|
auto filter_var = pattern->new_node(filter_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
filter_name());
|
|
|
|
|
auto input_var = pattern->new_node(input_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
input_name());
|
|
|
|
|
|
|
|
|
|
auto filter_var = pattern->new_node(filter_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
filter_name());
|
|
|
|
|
|
|
|
|
|
auto output_var = pattern->new_node(output_name())
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
output_name());
|
|
|
|
|
auto output_var = pattern->new_node(output_name())
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
output_name());
|
|
|
|
|
|
|
|
|
|
conv_op->LinksFrom({input_var, filter_var});
|
|
|
|
|
conv_op->LinksTo({output_var});
|
|
|
|
|
conv_op->LinksFrom({input_var, filter_var});
|
|
|
|
|
conv_op->LinksTo({output_var});
|
|
|
|
|
|
|
|
|
|
return output_var;
|
|
|
|
|
return output_var;
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -139,7 +139,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [&](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
|
|
|
|
|
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
@ -147,7 +147,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {y->Name()});
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("fuse_sum", true);
|
|
|
|
|
op_desc.SetAttr("use_mkldnn", true);
|
|
|
|
|
op_desc.SetAttr("fuse_eltwise", true);
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = g->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
@ -175,7 +176,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
|
|
|
|
|
patterns::GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
gpd(graph.get(), handler);
|
|
|
|
|