|
|
|
@ -118,6 +118,7 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
|
|
|
|
|
if (same != std::end(node.inputs)) {
|
|
|
|
|
LinkNodes(to, &node);
|
|
|
|
|
node.Op()->SetInput("X", {to->Name()});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -145,7 +146,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetInput("EltwiseParameter", {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("use_mkldnn", true);
|
|
|
|
@ -155,6 +156,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
patterns::LinkNodes(conv_input, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(conv_filter, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(elementwise_add_x, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|