|
|
|
@ -24,35 +24,6 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
// The function keeps the graph consistent by replacing
|
|
|
|
|
// a node 'from' in the set of inputs nodes
|
|
|
|
|
// of the visited node by a node 'to'.
|
|
|
|
|
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
|
auto from_in_inputs =
|
|
|
|
|
std::find(std::begin(node.inputs), std::end(node.inputs), from);
|
|
|
|
|
|
|
|
|
|
if (from_in_inputs != std::end(node.inputs)) {
|
|
|
|
|
IR_NODE_LINK_TO(to, (&node));
|
|
|
|
|
|
|
|
|
|
auto inputs = node.Op()->Inputs();
|
|
|
|
|
|
|
|
|
|
using input_type = VariableNameMap::value_type;
|
|
|
|
|
|
|
|
|
|
std::for_each(std::begin(inputs), std::end(inputs),
|
|
|
|
|
[from, to, &node](const input_type& i) -> void {
|
|
|
|
|
auto param_names = i.second;
|
|
|
|
|
auto pi = std::find(std::begin(param_names),
|
|
|
|
|
std::end(param_names), from->Name());
|
|
|
|
|
|
|
|
|
|
if (pi != std::end(param_names)) {
|
|
|
|
|
node.Op()->SetInput(i.first, {to->Name()});
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
|
|
|
|
|
auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
|
|
|
|
|
for (auto n : graph->Nodes()) {
|
|
|
|
@ -99,25 +70,12 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Finds the input node of 'op' that is bound to the input slot 'bias_name'.
//
// op        - operator node whose op description and input nodes are searched.
// bias_name - name of the input slot to look up (e.g. "Bias").
//
// Returns the node from op.inputs whose name equals the first argument of the
// slot. Returns boost::none when the slot does not exist, when it has no
// arguments, or when none of the input nodes carries that name.
boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
  const auto& op_inputs = op.Op()->Inputs();
  const auto slot = op_inputs.find(bias_name);
  if (slot == std::end(op_inputs) || slot->second.empty()) {
    return boost::none;
  }

  const auto& wanted_name = slot->second[0];
  const auto bias_node = std::find_if(
      std::begin(op.inputs), std::end(op.inputs),
      [&wanted_name](Node* n) -> bool { return n->Name() == wanted_name; });

  // The slot may name a variable that has no matching input node; the end
  // iterator must not be dereferenced in that case.
  if (bias_node == std::end(op.inputs)) {
    return boost::none;
  }

  return *bias_node;
}
|
|
|
|
|
template <typename T>
|
|
|
|
|
boost::optional<T> HasAttribute(const Node& op, const std::string& attr) {
|
|
|
|
|
if (op.Op()->HasAttr(attr))
|
|
|
|
|
return boost::get<T>(op.Op()->GetAttr(attr));
|
|
|
|
|
else
|
|
|
|
|
return boost::none;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
|
|
|
|
@ -151,40 +109,18 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
|
|
|
|
|
|
|
|
|
|
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
auto fuse_relu = HasAttribute<bool>(*conv_op, "fuse_relu");
|
|
|
|
|
if (fuse_relu && *fuse_relu) return;
|
|
|
|
|
|
|
|
|
|
auto conv_bias = HasBias(*conv_op, "Bias");
|
|
|
|
|
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
|
|
|
|
|
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
|
|
|
|
|
conv_op->Op()->SetAttr("fuse_residual_connection", true);
|
|
|
|
|
|
|
|
|
|
if (conv_bias) {
|
|
|
|
|
op_desc.SetInput("Bias", {(*conv_bias)->Name()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
|
|
|
|
|
op_desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("fuse_residual_connection", true);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {conv_output, elementwise_add_op});
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = graph->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(conv_input, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, conv_output);
|
|
|
|
|
|
|
|
|
|
if (conv_bias) {
|
|
|
|
|
IR_NODE_LINK_TO((*conv_bias), fused_conv_op);
|
|
|
|
|
}
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
|
|
|
|
|
|
|
|
|
|
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(graph,
|
|
|
|
|
{elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
(*fusion_stats)++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -229,60 +165,33 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
|
|
|
|
|
|
|
|
|
|
Node* projection_node;
|
|
|
|
|
Node* residual_conv_op;
|
|
|
|
|
Node* residual_conv_input;
|
|
|
|
|
Node* residual_conv_filter;
|
|
|
|
|
Node* residual_conv_output;
|
|
|
|
|
|
|
|
|
|
if (IsReachable(graph, conv_x_input, conv_y_output)) {
|
|
|
|
|
projection_node = conv_x_output;
|
|
|
|
|
residual_conv_op = conv_y_op;
|
|
|
|
|
residual_conv_input = conv_y_input;
|
|
|
|
|
residual_conv_filter = conv_y_filter;
|
|
|
|
|
residual_conv_output = conv_y_output;
|
|
|
|
|
} else if (IsReachable(graph, conv_y_input, conv_x_output)) {
|
|
|
|
|
projection_node = conv_y_output;
|
|
|
|
|
residual_conv_op = conv_x_op;
|
|
|
|
|
residual_conv_input = conv_x_input;
|
|
|
|
|
residual_conv_filter = conv_x_filter;
|
|
|
|
|
residual_conv_output = conv_x_output;
|
|
|
|
|
} else {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
auto fuse_relu = HasAttribute<bool>(*residual_conv_op, "fuse_relu");
|
|
|
|
|
if (fuse_relu && *fuse_relu) return;
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {residual_conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {residual_conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {projection_node->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {residual_conv_output->Name()});
|
|
|
|
|
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
|
|
|
|
|
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
|
|
|
|
|
|
|
|
|
|
auto residual_conv_bias = HasBias(*residual_conv_op, "Bias");
|
|
|
|
|
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
|
|
|
|
|
|
|
|
|
|
if (residual_conv_bias) {
|
|
|
|
|
op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) {
|
|
|
|
|
op_desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("fuse_residual_connection", true);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {residual_conv_output, elementwise_add_op});
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = graph->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(residual_conv_input, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(projection_node, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, residual_conv_output);
|
|
|
|
|
|
|
|
|
|
if (residual_conv_bias) {
|
|
|
|
|
IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op);
|
|
|
|
|
}
|
|
|
|
|
IR_NODE_LINK_TO(projection_node, residual_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
|
|
|
|
|
|
|
|
|
|
CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(
|
|
|
|
|
graph, {elementwise_add_out, residual_conv_op, elementwise_add_op});
|
|
|
|
|
(*fusion_stats)++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|