|
|
|
@ -40,18 +40,20 @@ framework::proto::OpDesc PrepareOpDesc(
|
|
|
|
|
const std::string& output) {
|
|
|
|
|
auto proto = base_desc;
|
|
|
|
|
framework::OpDesc desc(proto, nullptr);
|
|
|
|
|
desc.SetType("conv2d_fusion");
|
|
|
|
|
desc.SetInput("Bias", {bias});
|
|
|
|
|
desc.SetInput("ResidualData", {bias1});
|
|
|
|
|
desc.SetAttr("activation", activation);
|
|
|
|
|
desc.SetOutput("Output", {output});
|
|
|
|
|
desc.SetAttr("is_test", true);
|
|
|
|
|
|
|
|
|
|
desc.SetAttr("use_cudnn", false);
|
|
|
|
|
desc.Flush();
|
|
|
|
|
return *desc.Proto();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
|
|
|
|
|
std::unique_ptr<ir::Graph> graph) const {
|
|
|
|
|
const std::string pattern_name = "conv_elementwise_add_act_fuse";
|
|
|
|
|
const std::string pattern_name = "conv_elementwise_add2_act_fuse";
|
|
|
|
|
FusePassBase::Init(pattern_name, graph.get());
|
|
|
|
|
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
@ -76,22 +78,23 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
|
|
|
|
|
framework::OpDesc new_op_desc(new_op_proto, nullptr);
|
|
|
|
|
|
|
|
|
|
// Create a new node for the fused op.
|
|
|
|
|
graph->CreateOpNode(&new_op_desc);
|
|
|
|
|
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
|
|
|
|
|
|
|
|
|
|
// Link inputs and outputs.
|
|
|
|
|
PADDLE_ENFORCE(subgraph.count(x));
|
|
|
|
|
auto* conv_in_node = subgraph.at(x);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(conv_in_node, conv_op); // Input
|
|
|
|
|
IR_NODE_LINK_TO(conv_filter, conv_op); // Filter
|
|
|
|
|
IR_NODE_LINK_TO(conv_op, conv_out); // Output
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_in_y, conv_op); // Bias
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_in_y_1, conv_op); // Bias
|
|
|
|
|
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
|
|
|
|
|
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_in_y_1, new_conv_op); // Bias
|
|
|
|
|
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
|
|
|
|
|
|
|
|
|
|
// Delete the unneeded nodes.
|
|
|
|
|
GraphSafeRemoveNodes(graph.get(),
|
|
|
|
|
{conv_op, elementwise_add_op, elementwise_add_op_1,
|
|
|
|
|
elementwise_add_out});
|
|
|
|
|
GraphSafeRemoveNodes(
|
|
|
|
|
graph.get(),
|
|
|
|
|
{conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
|
|
|
|
|
elementwise_add_out, elementwise_add_out_1, act_op});
|
|
|
|
|
};
|
|
|
|
|
gpd(graph.get(), handler);
|
|
|
|
|
return graph;
|
|
|
|
|