|
|
|
@ -120,17 +120,18 @@ boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
|
|
|
|
|
return boost::none;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
|
|
|
|
|
get_node_from_elementwise_add_op,
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
|
|
|
|
|
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
|
|
|
|
|
get_node_from_conv_op,
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
|
|
|
|
|
get_node_from_elementwise_add_op)
|
|
|
|
|
: fusion_stats{std::make_shared<int>(0)},
|
|
|
|
|
can_fuse_func{can_fuse_func},
|
|
|
|
|
get_node_from_conv_op{get_node_from_conv_op},
|
|
|
|
|
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
|
|
|
|
|
can_fuse_func{can_fuse_func} {}
|
|
|
|
|
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
|
|
|
|
|
|
|
|
|
|
void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
|
|
|
|
|
void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
|
|
|
|
|
Node* conv_op;
|
|
|
|
|
Node* conv_input;
|
|
|
|
@ -187,6 +188,104 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
|
|
|
|
|
(*fusion_stats)++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Builds a handler for the projection variant of the residual-connection
// fusion.
//
// Parameters:
//   can_fuse_func                    - predicate later called with
//                                      (conv op, elementwise_add op) to decide
//                                      whether the pair may be fused.
//   get_node_from_conv_x_op          - extracts the nodes of the first matched
//                                      convolution from a subgraph.
//   get_node_from_conv_y_op          - extracts the nodes of the second matched
//                                      convolution from a subgraph.
//   get_node_from_elementwise_add_op - extracts the elementwise_add op and its
//                                      output node from a subgraph.
//
// All callables are copied into the handler. The fusion counter starts at
// zero and lives in a shared_ptr.
ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
    const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
    const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
        get_node_from_conv_x_op,
    const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
        get_node_from_conv_y_op,
    const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
        get_node_from_elementwise_add_op)
    : fusion_stats(std::make_shared<int>(0)),
      can_fuse_func(can_fuse_func),
      get_node_from_conv_x_op(get_node_from_conv_x_op),
      get_node_from_conv_y_op(get_node_from_conv_y_op),
      get_node_from_elementwise_add_op(get_node_from_elementwise_add_op) {}
|
|
|
|
|
|
|
|
|
|
void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
|
|
|
|
|
Node* conv_x_op;
|
|
|
|
|
Node* conv_x_input;
|
|
|
|
|
Node* conv_x_filter;
|
|
|
|
|
Node* conv_x_output;
|
|
|
|
|
|
|
|
|
|
Node* conv_y_op;
|
|
|
|
|
Node* conv_y_input;
|
|
|
|
|
Node* conv_y_filter;
|
|
|
|
|
Node* conv_y_output;
|
|
|
|
|
|
|
|
|
|
Node* elementwise_add_op;
|
|
|
|
|
Node* elementwise_add_out;
|
|
|
|
|
|
|
|
|
|
std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
|
|
|
|
|
get_node_from_conv_x_op(subgraph);
|
|
|
|
|
std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
|
|
|
|
|
get_node_from_conv_y_op(subgraph);
|
|
|
|
|
std::tie(elementwise_add_op, elementwise_add_out) =
|
|
|
|
|
get_node_from_elementwise_add_op(subgraph);
|
|
|
|
|
|
|
|
|
|
if (!can_fuse_func(conv_x_op, elementwise_add_op)) return;
|
|
|
|
|
if (!can_fuse_func(conv_y_op, elementwise_add_op)) return;
|
|
|
|
|
|
|
|
|
|
Node* projection_node;
|
|
|
|
|
Node* residual_conv_op;
|
|
|
|
|
Node* residual_conv_input;
|
|
|
|
|
Node* residual_conv_filter;
|
|
|
|
|
Node* residual_conv_output;
|
|
|
|
|
|
|
|
|
|
if (IsReachable(graph, conv_x_input, conv_y_output)) {
|
|
|
|
|
projection_node = conv_x_output;
|
|
|
|
|
residual_conv_op = conv_y_op;
|
|
|
|
|
residual_conv_input = conv_y_input;
|
|
|
|
|
residual_conv_filter = conv_y_filter;
|
|
|
|
|
residual_conv_output = conv_y_output;
|
|
|
|
|
} else if (IsReachable(graph, conv_y_input, conv_x_output)) {
|
|
|
|
|
projection_node = conv_y_output;
|
|
|
|
|
residual_conv_op = conv_x_op;
|
|
|
|
|
residual_conv_input = conv_x_input;
|
|
|
|
|
residual_conv_filter = conv_x_filter;
|
|
|
|
|
residual_conv_output = conv_x_output;
|
|
|
|
|
} else {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {residual_conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {residual_conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {projection_node->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {residual_conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
auto residual_conv_bias = HasBias(*residual_conv_op, "Bias");
|
|
|
|
|
|
|
|
|
|
if (residual_conv_bias) {
|
|
|
|
|
op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) {
|
|
|
|
|
op_desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("fuse_residual_connection", true);
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = graph->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(residual_conv_input, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(projection_node, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, residual_conv_output);
|
|
|
|
|
|
|
|
|
|
if (residual_conv_bias) {
|
|
|
|
|
IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(
|
|
|
|
|
graph, {elementwise_add_out, residual_conv_op, elementwise_add_op});
|
|
|
|
|
(*fusion_stats)++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::tuple<Node*, Node*, Node*, Node*>
|
|
|
|
|
ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
|
|
|
|
|
const patterns::Conv& conv_pattern,
|
|
|
|
@ -233,7 +332,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return ExecuteHandlerOnGraph(
|
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
|
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
|
|
|
|
|
return GetNodesFromConv(conv_pattern, subgraph);
|
|
|
|
@ -270,7 +369,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return ExecuteHandlerOnGraph(
|
|
|
|
|
return ExecuteHandleOnGraph<IdentityFuseHandle>(
|
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
|
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
|
|
|
|
|
return GetNodesFromConv(conv_pattern, subgraph);
|
|
|
|
@ -278,33 +377,54 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
get_node_from_elementwise_add);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph(
|
|
|
|
|
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
|
|
|
|
|
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
|
|
|
|
|
get_node_from_elementwise_add) const {
|
|
|
|
|
ir::Graph* graph;
|
|
|
|
|
int stats;
|
|
|
|
|
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
|
|
|
|
|
const std::string& name_scope,
|
|
|
|
|
const GraphWithStats& graph_with_stats) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
auto pattern = gpd.mutable_pattern();
|
|
|
|
|
|
|
|
|
|
std::tie(graph, stats) = graph_with_stats;
|
|
|
|
|
patterns::Conv conv_x_pattern{pattern, name_scope};
|
|
|
|
|
auto conv_x_output = conv_x_pattern();
|
|
|
|
|
|
|
|
|
|
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
|
|
|
|
|
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
|
|
|
|
|
};
|
|
|
|
|
patterns::Conv conv_y_pattern{pattern, name_scope};
|
|
|
|
|
auto conv_y_output = conv_y_pattern();
|
|
|
|
|
|
|
|
|
|
auto fuse_handler =
|
|
|
|
|
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
|
|
|
|
|
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
|
|
|
|
|
elementwise_add_pattern(conv_x_output, conv_y_output);
|
|
|
|
|
conv_x_output->AsIntermediate();
|
|
|
|
|
conv_y_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
(*gpd)(graph, fuse_handler);
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
return std::make_pair(graph, stats + fuse_handler.get_stats());
|
|
|
|
|
return std::make_tuple(elementwise_add_op, elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
|
|
|
|
|
&gpd, graph_with_stats,
|
|
|
|
|
[this,
|
|
|
|
|
&conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
|
|
|
|
|
return GetNodesFromConv(conv_x_pattern, subgraph);
|
|
|
|
|
},
|
|
|
|
|
[this,
|
|
|
|
|
&conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
|
|
|
|
|
return GetNodesFromConv(conv_y_pattern, subgraph);
|
|
|
|
|
},
|
|
|
|
|
get_node_from_elementwise_add);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
FusePassBase::Init(name_scope_, graph.get());
|
|
|
|
|
|
|
|
|
|
auto fused_graph_with_stats = FuseConvAsY(
|
|
|
|
|
name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0)));
|
|
|
|
|
name_scope_,
|
|
|
|
|
FuseConvAsX(
|
|
|
|
|
name_scope_,
|
|
|
|
|
FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
|
|
|
|
|
|
|
|
|
|
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
|
|
|
|
|
AddStatis(fused_graph_with_stats.second);
|
|
|
|
|